Skip to content

Commit 5954209

Browse files
authored
[BEAM-13984] Implement RunInference for PyTorch (#17196)
* Initial pytorch implementation * Clean up pytorch implementation; Works for single example * Fix for multiple examples in a batch * Fix header and documentation * Add multifeature tests; Add numpy/tensor conversion * Add torch to setup.py * Add ml to tox.ini * Remove numpy checks/conversions; Address PR comments * Remove GPU code and test * Add Filesystems * Add separate pytorch install and tox test * Fix typos in gradle and tox files * Add separate tox tests for pytorch; Remove torch setup install * fix import error * Add unittest main() * Add PredictionResult; Refactor tests * Fix docs; Remove keyed test * Add gcp to tox
1 parent 373c1c9 commit 5954209

File tree

6 files changed

+334
-0
lines changed

6 files changed

+334
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# pytype: skip-file
19+
20+
from typing import Iterable
21+
from typing import List
22+
23+
import torch
24+
from apache_beam.io.filesystems import FileSystems
25+
from apache_beam.ml.inference.api import PredictionResult
26+
from apache_beam.ml.inference.base import InferenceRunner
27+
from apache_beam.ml.inference.base import ModelLoader
28+
29+
30+
class PytorchInferenceRunner(InferenceRunner):
  """
  This class runs Pytorch inferences with the run_inference method. It also has
  other methods to get the bytes of a batch of Tensors as well as the namespace
  for Pytorch models.
  """
  def __init__(self, device: torch.device):
    # Device (e.g. cpu or cuda) that batches are moved to before inference.
    self._device = device

  def run_inference(self, batch: List[torch.Tensor],
                    model: torch.nn.Module) -> Iterable[torch.Tensor]:
    """
    Runs inferences on a batch of Tensors and returns an Iterable of
    Tensor Predictions.

    This method stacks the list of Tensors in a vectorized format to optimize
    the inference call.
    """
    batch = torch.stack(batch)
    # Only copy the batch to the target device when it is not already there.
    if batch.device != self._device:
      batch = batch.to(self._device)
    predictions = model(batch)
    # Pair each input example with its corresponding prediction.
    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

  def get_num_bytes(self, batch: List[torch.Tensor]) -> int:
    """Returns the number of bytes of data for a batch of Tensors."""
    # element_size() is the per-element byte width; multiplying by nelement()
    # counts every element of each Tensor. (Iterating a multi-dimensional
    # Tensor yields sub-Tensors, not scalars, so summing element_size() per
    # yielded item would undercount for rank >= 2 Tensors.)
    return sum(t.element_size() * t.nelement() for t in batch)

  def get_metrics_namespace(self) -> str:
    """
    Returns a namespace for metrics collected by the RunInference transform.
    """
    return 'RunInferencePytorch'
65+
66+
class PytorchModelLoader(ModelLoader):
  """Loads a Pytorch Model."""
  def __init__(
      self,
      state_dict_path: str,
      model_class: torch.nn.Module,
      device: str = 'CPU'):
    """
    state_dict_path: path to the saved dictionary of the model state.
    model_class: class of the Pytorch model that defines the model structure.
    device: the device on which you wish to run the model. If ``device = GPU``
        then device will be cuda if it is available. Otherwise, it will be cpu.

    See https://pytorch.org/tutorials/beginner/saving_loading_models.html
    for details
    """
    self._state_dict_path = state_dict_path
    # Fall back to CPU when a GPU is requested but CUDA is unavailable,
    # rather than failing at construction time.
    if device == 'GPU' and torch.cuda.is_available():
      self._device = torch.device('cuda')
    else:
      self._device = torch.device('cpu')
    self._model_class = model_class
    self._model_class.to(self._device)
    self._inference_runner = PytorchInferenceRunner(device=self._device)

  def load_model(self) -> torch.nn.Module:
    """Loads and initializes a Pytorch model for processing."""
    model = self._model_class
    # Use a context manager so the file handle is always closed, even when
    # torch.load raises (the previous code never closed it).
    with FileSystems.open(self._state_dict_path, 'rb') as file:
      model.load_state_dict(torch.load(file))
    # eval() makes layers such as dropout/batch-norm deterministic.
    model.eval()
    return model

  def get_inference_runner(self) -> InferenceRunner:
    """Returns a Pytorch implementation of InferenceRunner."""
    return self._inference_runner
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# pytype: skip-file
19+
20+
import os
21+
import shutil
22+
import tempfile
23+
import unittest
24+
from collections import OrderedDict
25+
26+
import numpy as np
27+
import pytest
28+
29+
import apache_beam as beam
30+
from apache_beam.testing.test_pipeline import TestPipeline
31+
from apache_beam.testing.util import assert_that
32+
from apache_beam.testing.util import equal_to
33+
34+
# Protect against environments where pytorch library is not available.
35+
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
36+
try:
37+
import torch
38+
from apache_beam.ml.inference.api import PredictionResult
39+
from apache_beam.ml.inference.base import RunInference
40+
from apache_beam.ml.inference.pytorch import PytorchInferenceRunner
41+
from apache_beam.ml.inference.pytorch import PytorchModelLoader
42+
except ImportError:
43+
raise unittest.SkipTest('PyTorch dependencies are not installed')
44+
45+
46+
def _compare_prediction_result(a, b):
47+
return (
48+
torch.equal(a.inference, b.inference) and
49+
torch.equal(a.example, b.example))
50+
51+
52+
class PytorchLinearRegression(torch.nn.Module):
  """A minimal linear-regression model: one fully-connected layer."""
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.linear = torch.nn.Linear(input_dim, output_dim)

  def forward(self, x):
    # Affine transform y = x @ W.T + b.
    return self.linear(x)
61+
62+
@pytest.mark.uses_pytorch
class PytorchRunInferenceTest(unittest.TestCase):
  """Unit and pipeline tests for the Pytorch RunInference implementation."""
  def setUp(self):
    # Scratch directory for state_dicts saved by individual tests.
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def test_inference_runner_single_tensor_feature(self):
    """Runs inference on single-feature examples against y = 2x + 0.5."""
    examples = [
        torch.from_numpy(np.array([1], dtype="float32")),
        torch.from_numpy(np.array([5], dtype="float32")),
        torch.from_numpy(np.array([-3], dtype="float32")),
        torch.from_numpy(np.array([10.0], dtype="float32")),
    ]
    expected_predictions = [
        PredictionResult(ex, pred) for ex,
        pred in zip(
            examples,
            torch.Tensor([example * 2.0 + 0.5
                          for example in examples]).reshape(-1, 1))
    ]

    # Build the model with known weights so expected outputs are exact.
    model = PytorchLinearRegression(input_dim=1, output_dim=1)
    model.load_state_dict(
        OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
                     ('linear.bias', torch.Tensor([0.5]))]))
    model.eval()

    inference_runner = PytorchInferenceRunner(torch.device('cpu'))
    predictions = inference_runner.run_inference(examples, model)
    for actual, expected in zip(predictions, expected_predictions):
      self.assertTrue(_compare_prediction_result(actual, expected))

  def test_inference_runner_multiple_tensor_features(self):
    """Runs inference on two-feature examples against y = 2*f1 + 3*f2 + 0.5."""
    # NOTE: a dead assignment (examples built from a reshaped numpy array and
    # immediately overwritten by the list below) was removed.
    examples = [
        torch.from_numpy(np.array([1, 5], dtype="float32")),
        torch.from_numpy(np.array([3, 10], dtype="float32")),
        torch.from_numpy(np.array([-14, 0], dtype="float32")),
        torch.from_numpy(np.array([0.5, 0.5], dtype="float32")),
    ]
    expected_predictions = [
        PredictionResult(ex, pred) for ex,
        pred in zip(
            examples,
            torch.Tensor([f1 * 2.0 + f2 * 3 + 0.5
                          for f1, f2 in examples]).reshape(-1, 1))
    ]

    model = PytorchLinearRegression(input_dim=2, output_dim=1)
    model.load_state_dict(
        OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
                     ('linear.bias', torch.Tensor([0.5]))]))
    model.eval()

    inference_runner = PytorchInferenceRunner(torch.device('cpu'))
    predictions = inference_runner.run_inference(examples, model)
    for actual, expected in zip(predictions, expected_predictions):
      self.assertTrue(_compare_prediction_result(actual, expected))

  def test_num_bytes(self):
    """get_num_bytes must account for every element of every Tensor."""
    inference_runner = PytorchInferenceRunner(torch.device('cpu'))
    # 8 float32 elements total, arranged as a (4, 2) Tensor.
    examples = torch.from_numpy(
        np.array([1, 5, 3, 10, -14, 0, 0.5, 0.5],
                 dtype="float32")).reshape(-1, 2)
    self.assertEqual((examples[0].element_size()) * 8,
                     inference_runner.get_num_bytes(examples))

  def test_namespace(self):
    inference_runner = PytorchInferenceRunner(torch.device('cpu'))
    self.assertEqual(
        'RunInferencePytorch', inference_runner.get_metrics_namespace())

  def test_pipeline_local_model(self):
    """End-to-end RunInference with a state_dict saved on local disk."""
    with TestPipeline() as pipeline:
      examples = torch.from_numpy(
          np.array([1, 5, 3, 10, -14, 0, 0.5, 0.5],
                   dtype="float32")).reshape(-1, 2)
      expected_predictions = [
          PredictionResult(ex, pred) for ex,
          pred in zip(
              examples,
              torch.Tensor([f1 * 2.0 + f2 * 3 + 0.5
                            for f1, f2 in examples]).reshape(-1, 1))
      ]

      state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0, 3]])),
                                ('linear.bias', torch.Tensor([0.5]))])
      path = os.path.join(self.tmpdir, 'my_state_dict_path')
      torch.save(state_dict, path)

      model_loader = PytorchModelLoader(
          state_dict_path=path,
          model_class=PytorchLinearRegression(input_dim=2, output_dim=1))

      pcoll = pipeline | 'start' >> beam.Create(examples)
      predictions = pcoll | RunInference(model_loader)
      assert_that(
          predictions,
          equal_to(expected_predictions, equals_fn=_compare_prediction_result))

  def test_pipeline_gcs_model(self):
    """End-to-end RunInference loading the state_dict from GCS."""
    with TestPipeline() as pipeline:
      examples = torch.from_numpy(
          np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1))
      expected_predictions = [
          PredictionResult(ex, pred) for ex,
          pred in zip(
              examples,
              torch.Tensor([example * 2.0 + 0.5
                            for example in examples]).reshape(-1, 1))
      ]

      gs_pth = 'gs://apache-beam-ml/pytorch_lin_reg_model_2x+0.5_state_dict.pth'
      model_loader = PytorchModelLoader(
          state_dict_path=gs_pth,
          model_class=PytorchLinearRegression(input_dim=1, output_dim=1))

      pcoll = pipeline | 'start' >> beam.Create(examples)
      predictions = pcoll | RunInference(model_loader)
      assert_that(
          predictions,
          equal_to(expected_predictions, equals_fn=_compare_prediction_result))

  def test_invalid_input_type(self):
    """Non-Tensor batch elements should raise TypeError (from torch.stack)."""
    with self.assertRaisesRegex(TypeError, "expected Tensor as element"):
      with TestPipeline() as pipeline:
        # numpy arrays, not torch Tensors, are deliberately passed in.
        examples = np.array([1, 5, 3, 10], dtype="float32").reshape(-1, 1)

        state_dict = OrderedDict([('linear.weight', torch.Tensor([[2.0]])),
                                  ('linear.bias', torch.Tensor([0.5]))])
        path = os.path.join(self.tmpdir, 'my_state_dict_path')
        torch.save(state_dict, path)

        model_loader = PytorchModelLoader(
            state_dict_path=path,
            model_class=PytorchLinearRegression(input_dim=1, output_dim=1))

        pcoll = pipeline | 'start' >> beam.Create(examples)
        # pylint: disable=expression-not-assigned
        pcoll | RunInference(model_loader)
208+
# Allow running this test module directly (e.g. `python pytorch_test.py`).
if __name__ == '__main__':
  unittest.main()

sdks/python/container/license_scripts/dep_urls_py.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,5 +145,7 @@ pip_dependencies:
145145
license: "https://raw.githubusercontent.com/PAIR-code/what-if-tool/master/LICENSE"
146146
timeloop:
147147
license: "https://raw.githubusercontent.com/sankalpjonn/timeloop/master/LICENSE"
148+
torch:
149+
license: "https://raw.githubusercontent.com/pytorch/pytorch/master/LICENSE"
148150
wget:
149151
license: "https://raw.githubusercontent.com/mirror/wget/master/COPYING"

sdks/python/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ markers =
4545
no_xdist: run without pytest-xdist plugin
4646
# We run these tests with multiple major pyarrow versions (BEAM-11211)
4747
uses_pyarrow: tests that utilize pyarrow in some way
48+
uses_pytorch: tests that utilize pytorch in some way
4849

4950
# Default timeout intended for unit tests.
5051
# If certain tests need a different value, please see the docs on how to

sdks/python/test-suites/tox/py38/build.gradle

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ toxTask "testPy38pandas-14", "py38-pandas-14"
8585
test.dependsOn "testPy38pandas-14"
8686
preCommitPy38.dependsOn "testPy38pandas-14"
8787

88+
// Create a test task for each minor version of pytorch
89+
toxTask "testPy38pytorch-19", "py38-pytorch-19"
90+
test.dependsOn "testPy38pytorch-19"
91+
preCommitPy38.dependsOn "testPy38pytorch-19"
92+
93+
toxTask "testPy38pytorch-110", "py38-pytorch-110"
94+
test.dependsOn "testPy38pytorch-110"
95+
preCommitPy38.dependsOn "testPy38pytorch-110"
96+
8897
toxTask "whitespacelint", "whitespacelint"
8998

9099
task archiveFilesToLint(type: Zip) {

sdks/python/tox.ini

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,15 @@ commands =
264264
/bin/sh -c "pip freeze | grep -E '(pandas|numpy)'"
265265
# Run all DataFrame API unit tests
266266
{toxinidir}/scripts/run_pytest.sh {envname} 'apache_beam/dataframe'
267+
268+
[testenv:py{37,38,39}-pytorch-{19,110}]
269+
deps =
270+
-r build-requirements.txt
271+
19: torch>=1.9.0,<1.10.0
272+
110: torch>=1.10.0,<1.11.0
273+
extras = test,gcp
274+
commands =
275+
# Log torch version for debugging
276+
/bin/sh -c "pip freeze | grep -E torch"
277+
# Run all PyTorch unit tests
278+
pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_pytorch {posargs}

0 commit comments

Comments
 (0)