Skip to content

Commit 4e51062

Browse files
AnandInguva and yeandy authored and committed
Merge pull request apache#21781: Sklearn MNIST example and IT test
* sklearn example and IT test
* Change the example name
* Refactor sklearn example
* Refactor and add assertions to the sklearn test
* Fix up import order
* fixup: help and name
* Add gradle task for sklearn IT tests
* Fix up lint
* Update sdks/python/test-suites/direct/common.gradle (co-authored-by: Andy Ye <andyye333@gmail.com>)
* Change sklearn IT test marker
* Uncomment
* Apply suggestions from code review

Co-authored-by: Andy Ye <andyye333@gmail.com>
1 parent b989efe commit 4e51062

File tree

4 files changed

+214
-3
lines changed

4 files changed

+214
-3
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""A pipeline that uses RunInference API to classify MNIST data.
19+
20+
This pipeline takes a text file in which data is comma separated ints. The first
21+
column would be the true label and the rest would be the pixel values. The data
22+
is processed and then a model trained on the MNIST data would be used to perform
23+
the inference. The pipeline writes the prediction to an output file in which
24+
users can then compare against the true label.
25+
"""
26+
27+
import argparse
28+
from typing import Iterable
29+
from typing import List
30+
from typing import Tuple
31+
32+
import apache_beam as beam
33+
from apache_beam.ml.inference.base import KeyedModelHandler
34+
from apache_beam.ml.inference.base import PredictionResult
35+
from apache_beam.ml.inference.base import RunInference
36+
from apache_beam.ml.inference.sklearn_inference import ModelFileType
37+
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
38+
from apache_beam.options.pipeline_options import PipelineOptions
39+
from apache_beam.options.pipeline_options import SetupOptions
40+
41+
42+
def process_input(row: str) -> Tuple[int, List[int]]:
    """Parses one comma-separated MNIST row into (true label, pixel values).

    The first field is the integer label; every remaining field is a pixel
    intensity, converted to int.
    """
    fields = row.split(',')
    return int(fields[0]), [int(value) for value in fields[1:]]
47+
48+
49+
class PostProcessor(beam.DoFn):
    """Extracts the predicted label from a keyed PredictionResult.

    Emits one comma-separated string per element containing the true label
    followed by the model's prediction.
    """
    def process(self, element: Tuple[int, PredictionResult]) -> Iterable[str]:
        true_label, result = element
        yield '{},{}'.format(true_label, result.inference)
57+
58+
59+
def parse_known_args(argv):
    """Parses args for the workflow.

    Returns a (known_args, pipeline_args) pair; unrecognized flags are left
    for the Beam pipeline options to consume.
    """
    parser = argparse.ArgumentParser()
    # (flag, dest, help) triples for the three required workflow arguments.
    required_flags = (
        ('--input_file', 'input', 'text file with comma separated int values.'),
        ('--output', 'output', 'Path to save output predictions.'),
        ('--model_path', 'model_path',
         'Path to load the Sklearn model for Inference.'),
    )
    for flag, dest, help_text in required_flags:
        parser.add_argument(flag, dest=dest, required=True, help=help_text)
    return parser.parse_known_args(argv)
78+
79+
80+
def run(argv=None, save_main_session=True):
    """Entry point. Defines and runs the pipeline.

    Reads comma-separated MNIST rows, runs sklearn inference keyed by the
    true label, and writes 'true_label,prediction' lines to the output path.
    """
    known_args, pipeline_args = parse_known_args(argv)
    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

    # Inputs to RunInference are keyed (label, pixels) tuples, so the sklearn
    # handler is wrapped in a KeyedModelHandler.
    model_handler = KeyedModelHandler(
        SklearnModelHandlerNumpy(
            model_file_type=ModelFileType.PICKLE,
            model_uri=known_args.model_path))

    with beam.Pipeline(options=pipeline_options) as pipeline:
        predictions = (
            pipeline
            | "ReadFromInput" >> beam.io.ReadFromText(
                known_args.input, skip_header_lines=1)
            | "PreProcessInputs" >> beam.Map(process_input)
            | "RunInference" >> RunInference(model_handler)
            | "PostProcessOutputs" >> beam.ParDo(PostProcessor()))

        _ = predictions | "WriteOutput" >> beam.io.WriteToText(
            known_args.output,
            shard_name_template='',
            append_trailing_newlines=True)
109+
110+
111+
# Allows running the example directly, e.g.:
#   python sklearn_mnist_classification.py --input_file ... --output ... \
#       --model_path ...
if __name__ == '__main__':
    run()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""End-to-End test for Sklearn Inference"""
19+
20+
import logging
21+
import unittest
22+
import uuid
23+
24+
import pytest
25+
26+
from apache_beam.examples.inference import sklearn_mnist_classification
27+
from apache_beam.io.filesystems import FileSystems
28+
from apache_beam.testing.test_pipeline import TestPipeline
29+
30+
31+
def process_outputs(filepath):
    """Reads a results file and returns its lines, decoded and newline-stripped."""
    with FileSystems().open(filepath) as result_file:
        return [
            raw.decode('utf-8').strip('\n') for raw in result_file.readlines()
        ]
36+
37+
38+
@pytest.mark.uses_sklearn
@pytest.mark.it_postcommit
class SklearnInference(unittest.TestCase):
    """End-to-end test for the sklearn MNIST classification example.

    Runs the example pipeline against GCS-hosted inputs and a pre-trained
    model, then compares the written predictions against expected outputs.
    """
    def test_sklearn_mnist_classification(self):
        # NOTE(review): removed @pytest.mark.skip — this test is wired into
        # the inferencePostCommitIT gradle suite and would otherwise never
        # run. is_integration_test=True gates it to IT suites and supplies
        # the IT pipeline options.
        test_pipeline = TestPipeline(is_integration_test=True)
        input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv'
        output_file_dir = 'gs://temp-storage-for-end-to-end-tests'
        # Unique subdirectory per run so concurrent runs don't collide.
        output_file = '/'.join(
            [output_file_dir, str(uuid.uuid4()), 'result.txt'])
        model_path = 'gs://apache-beam-ml/models/mnist_model_svm.pickle'
        extra_opts = {
            'input': input_file,
            'output': output_file,
            'model_path': model_path,
        }
        sklearn_mnist_classification.run(
            test_pipeline.get_full_options_as_args(**extra_opts),
            save_main_session=False)
        self.assertEqual(FileSystems().exists(output_file), True)

        expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/test_sklearn_mnist_classification_actuals.txt'  # pylint: disable=line-too-long
        expected_outputs = process_outputs(expected_output_filepath)
        predicted_outputs = process_outputs(output_file)
        self.assertEqual(len(expected_outputs), len(predicted_outputs))

        # Output order is not deterministic, so key predictions by true label
        # before comparing against the expected outputs.
        predictions_dict = {}
        for predicted in predicted_outputs:
            true_label, prediction = predicted.split(',')
            predictions_dict[true_label] = prediction

        for expected in expected_outputs:
            true_label, expected_prediction = expected.split(',')
            self.assertEqual(predictions_dict[true_label], expected_prediction)
72+
73+
74+
if __name__ == '__main__':
    # DEBUG logging makes pipeline progress visible in the IT suite logs.
    logging.getLogger().setLevel(logging.DEBUG)
    unittest.main()

sdks/python/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ markers =
4848
# We run these tests with multiple major pyarrow versions (BEAM-11211)
4949
uses_pyarrow: tests that utilize pyarrow in some way
5050
uses_pytorch: tests that utilize pytorch in some way
51+
uses_sklearn: tests that utilize scikit-learn in some way
5152

5253
# Default timeout intended for unit tests.
5354
# If certain tests need a different value, please see the docs on how to

sdks/python/test-suites/direct/common.gradle

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ tasks.register("hdfsIntegrationTest") {
187187
}
188188

189189
// Pytorch RunInference IT tests
190-
task torchTests {
190+
task torchInferenceTest {
191191
dependsOn 'installGcpTest'
192192
dependsOn ':sdks:python:sdist'
193193
def requirementsFile = "${rootDir}/sdks/python/apache_beam/ml/inference/torch_tests_requirements.txt"
@@ -211,10 +211,32 @@ task torchTests {
211211
args '-c', ". ${envdir}/bin/activate && export FORCE_TORCH_IT=1 && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
212212
}
213213
}
214+
215+
}
216+
// Scikit-learn RunInference IT tests
task sklearnInferenceTest {
  dependsOn 'installGcpTest'
  dependsOn ':sdks:python:sdist'
  doLast {
    def testOpts = basicTestOpts
    // Collect only tests marked both uses_sklearn and it_postcommit, and run
    // them on the direct runner. (Fixed stray whitespace before the comma.)
    def argMap = [
        "test_opts": testOpts,
        "suite": "postCommitIT-direct-py${pythonVersionSuffix}",
        "collect": "uses_sklearn and it_postcommit",
        "runner": "TestDirectRunner"
    ]
    def cmdArgs = mapToArgString(argMap)
    exec {
      executable 'sh'
      args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs"
    }
  }
}
215235

216236
// Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite.
217-
// TODO(anandinguva): Add sklearn IT test here
218237
project.tasks.register("inferencePostCommitIT") {
219-
dependsOn = ['torchTests']
238+
dependsOn = [
239+
'torchInferenceTest',
240+
'sklearnInferenceTest'
241+
]
220242
}

0 commit comments

Comments
 (0)