From 917828e7edd71465a4275e84ee0a485730a3145a Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 4 Apr 2019 11:34:58 -0700 Subject: [PATCH 1/2] change: add csv deserializer --- src/sagemaker/predictor.py | 14 ++++++++++++++ tests/unit/test_predictor.py | 15 +++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 5da69dfb2f..820e0a691e 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -229,6 +229,20 @@ def _row_to_csv(obj): return ','.join(obj) +class _CsvDeserializer(object): + def __init__(self): + self.content_type = CONTENT_TYPE_CSV + + def __call__(self, stream, content_type): + try: + return list(csv.reader(stream.read().decode('utf-8').splitlines())) + finally: + stream.close() + + +csv_deserializer = _CsvDeserializer() + + class BytesDeserializer(object): """Return the response as an undecoded array of bytes. diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 2b1434c580..ba1385d4ae 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -21,8 +21,9 @@ import numpy as np from sagemaker.predictor import RealTimePredictor -from sagemaker.predictor import json_serializer, json_deserializer, csv_serializer, BytesDeserializer, \ - StringDeserializer, StreamDeserializer, numpy_deserializer, npy_serializer, _NumpyDeserializer +from sagemaker.predictor import json_serializer, json_deserializer, csv_serializer, \ + csv_deserializer, BytesDeserializer, StringDeserializer, StreamDeserializer, \ + numpy_deserializer, npy_serializer, _NumpyDeserializer from tests.unit import DATA_DIR # testing serialization functions @@ -141,6 +142,16 @@ def test_csv_serializer_csv_reader(): assert result == validation_data +def test_csv_deserializer_array(): + result = csv_deserializer(io.BytesIO(b'1,2,3'), 'text/csv') + assert result == [['1', '2', '3']] + + +def test_csv_deserializer_2dimensional(): + result = csv_deserializer(io.BytesIO(b'1,2,3\n3,4,5'), 'text/csv') + assert result == [['1', '2', '3'], ['3', '4', '5']] + + def test_json_deserializer_array(): result = json_deserializer(io.BytesIO(b'[1, 2, 3]'), 'application/json') From 1249f53fab0629b6084cc37f0e34cd10b5fb20b3 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 4 Apr 2019 13:48:06 -0700 Subject: [PATCH 2/2] Address PR comments --- src/sagemaker/predictor.py | 7 ++++--- tests/unit/test_predictor.py | 5 +++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 820e0a691e..10edbfe727 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -230,12 +230,13 @@ def _row_to_csv(obj): class _CsvDeserializer(object): - def __init__(self): - self.content_type = CONTENT_TYPE_CSV + def __init__(self, encoding='utf-8'): + self.accept = CONTENT_TYPE_CSV + self.encoding = encoding def __call__(self, stream, content_type): try: - return list(csv.reader(stream.read().decode('utf-8').splitlines())) + return list(csv.reader(stream.read().decode(self.encoding).splitlines())) finally: stream.close() diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index ba1385d4ae..7be1e51a48 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -142,6 +142,11 @@ def test_csv_serializer_csv_reader(): assert result == validation_data +def test_csv_deserializer_single_element(): + result = csv_deserializer(io.BytesIO(b'1'), 'text/csv') + assert result == [['1']] + + def test_csv_deserializer_array(): result = csv_deserializer(io.BytesIO(b'1,2,3'), 'text/csv') assert result == [['1', '2', '3']]