From 17ac4b164343334c055a0b193abe9b6176ab34e6 Mon Sep 17 00:00:00 2001 From: Ashish Gupta Date: Tue, 22 Sep 2020 11:18:03 -0700 Subject: [PATCH 1/3] add inferentia pytorch config --- .../image_uri_config/inferentia-pytorch.json | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 src/sagemaker/image_uri_config/inferentia-pytorch.json diff --git a/src/sagemaker/image_uri_config/inferentia-pytorch.json b/src/sagemaker/image_uri_config/inferentia-pytorch.json new file mode 100644 index 0000000000..6cecbc3cc0 --- /dev/null +++ b/src/sagemaker/image_uri_config/inferentia-pytorch.json @@ -0,0 +1,14 @@ +{ + "processors": ["inf"], + "scope": ["inference"], + "versions": { + "1.5.1": { + "py_versions": ["py3"], + "registries": { + "us-east-1": "785573368785", + "us-west-2": "301217895009" + }, + "repository": "sagemaker-neo-pytorch" + } + } +} From e608acf272cfa8a966fb8a73ab478486fe288c34 Mon Sep 17 00:00:00 2001 From: Ashish Gupta Date: Tue, 22 Sep 2020 11:56:24 -0700 Subject: [PATCH 2/3] add test --- tests/conftest.py | 1 + tests/unit/sagemaker/image_uris/test_neo.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index a23f7e8e83..d74d953543 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,6 +44,7 @@ "coach_tensorflow", "inferentia_mxnet", "inferentia_tensorflow", + "inferentia_pytorch", "mxnet", "neo_mxnet", "neo_pytorch", diff --git a/tests/unit/sagemaker/image_uris/test_neo.py b/tests/unit/sagemaker/image_uris/test_neo.py index 9347e046a0..e427bf885d 100644 --- a/tests/unit/sagemaker/image_uris/test_neo.py +++ b/tests/unit/sagemaker/image_uris/test_neo.py @@ -117,6 +117,8 @@ def test_inferentia_mxnet(inferentia_mxnet_version): def test_inferentia_tensorflow(inferentia_tensorflow_version): _test_inferentia_framework_uris("tensorflow", inferentia_tensorflow_version) +def test_inferentia_pytorch(inferentia_pytorch_version): + _test_inferentia_framework_uris("pytorch", inferentia_pytorch_version) def _expected_framework_uri(framework, version, region="us-west-2", processor="cpu"): return expected_uris.framework_uri( From 710d4e9b44baee7da5d0ef06d0458804a847368e Mon Sep 17 00:00:00 2001 From: Ashish Gupta Date: Tue, 22 Sep 2020 12:06:06 -0700 Subject: [PATCH 3/3] fix styling --- tests/unit/sagemaker/image_uris/test_neo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/sagemaker/image_uris/test_neo.py b/tests/unit/sagemaker/image_uris/test_neo.py index e427bf885d..474e996086 100644 --- a/tests/unit/sagemaker/image_uris/test_neo.py +++ b/tests/unit/sagemaker/image_uris/test_neo.py @@ -117,9 +117,11 @@ def test_inferentia_mxnet(inferentia_mxnet_version): def test_inferentia_tensorflow(inferentia_tensorflow_version): _test_inferentia_framework_uris("tensorflow", inferentia_tensorflow_version) + def test_inferentia_pytorch(inferentia_pytorch_version): _test_inferentia_framework_uris("pytorch", inferentia_pytorch_version) + def _expected_framework_uri(framework, version, region="us-west-2", processor="cpu"): return expected_uris.framework_uri( "sagemaker-{}".format(framework),