diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 5da69dfb2f..10edbfe727 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -229,6 +229,21 @@ def _row_to_csv(obj): return ','.join(obj) +class _CsvDeserializer(object): + def __init__(self, encoding='utf-8'): + self.accept = CONTENT_TYPE_CSV + self.encoding = encoding + + def __call__(self, stream, content_type): + try: + return list(csv.reader(stream.read().decode(self.encoding).splitlines())) + finally: + stream.close() + + +csv_deserializer = _CsvDeserializer() + + class BytesDeserializer(object): """Return the response as an undecoded array of bytes. diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 2b1434c580..7be1e51a48 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -21,8 +21,9 @@ import numpy as np from sagemaker.predictor import RealTimePredictor -from sagemaker.predictor import json_serializer, json_deserializer, csv_serializer, BytesDeserializer, \ - StringDeserializer, StreamDeserializer, numpy_deserializer, npy_serializer, _NumpyDeserializer +from sagemaker.predictor import json_serializer, json_deserializer, csv_serializer, \ + csv_deserializer, BytesDeserializer, StringDeserializer, StreamDeserializer, \ + numpy_deserializer, npy_serializer, _NumpyDeserializer from tests.unit import DATA_DIR # testing serialization functions @@ -141,6 +142,21 @@ def test_csv_serializer_csv_reader(): assert result == validation_data +def test_csv_deserializer_single_element(): + result = csv_deserializer(io.BytesIO(b'1'), 'text/csv') + assert result == [['1']] + + +def test_csv_deserializer_array(): + result = csv_deserializer(io.BytesIO(b'1,2,3'), 'text/csv') + assert result == [['1', '2', '3']] + + +def test_csv_deserializer_2dimensional(): + result = csv_deserializer(io.BytesIO(b'1,2,3\n3,4,5'), 'text/csv') + assert result == [['1', '2', '3'], ['3', '4', '5']] + + def test_json_deserializer_array(): result = json_deserializer(io.BytesIO(b'[1, 2, 3]'), 'application/json')