From ab85b5d0788f81b19c96395c738a0ad7f4554efc Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 15 Jul 2020 18:36:01 +0100 Subject: [PATCH 1/4] Create test for sets serialisation support in XCom Relates to issue: #8703 --- tests/models/test_xcom.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 39586d1638468..887ef477d8ec6 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -212,3 +212,31 @@ def test_xcom_get_many(self): for result in results: self.assertEqual(result.value, json_obj) + + @conf_vars({("core", "xcom_enable_pickling"): "False"}) + def test_xcom_enable_set_type(self): + set_obj = {"set-value1", "set-value2"} + execution_date = timezone.utcnow() + key = "xcom_test5" + dag_id = "test_dag5" + task_id = "test_task5" + XCom.set(key=key, + value=set_obj, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + ret_value = XCom.get_one(key=key, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + self.assertEqual(ret_value, set_obj) + + session = settings.Session() + ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date + ).first().value + + self.assertEqual(ret_value, set_obj) From f3b75dc2f4ca1b4317d73a64543fa21dbbda9f02 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 16 Jul 2020 00:15:50 +0100 Subject: [PATCH 2/4] Create JSON serialisation module Based on airflow/serialization/serialized_objects.py:BaseSerialization The purpose was to create a lightweight module which could be used by XCom in order to serialise sets and other structures which contain nested sets. Relates to issue: #8703 --- airflow/serialization/json.py | 103 +++++++++++++++++ tests/serialization/test_json.py | 183 +++++++++++++++++++++++++++++++ 2 files changed, 286 insertions(+) create mode 100644 airflow/serialization/json.py create mode 100644 tests/serialization/test_json.py diff --git a/airflow/serialization/json.py b/airflow/serialization/json.py new file mode 100644 index 0000000000000..e6fa932bf4d5c --- /dev/null +++ b/airflow/serialization/json.py @@ -0,0 +1,103 @@ +import enum +import json +from typing import Any, Union + +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding + + +def encode(value: Any, type_: Any) -> [Encoding, Any]: + "Encode value and type into a JSON dict." + return {Encoding.VAR: value, Encoding.TYPE: type_} + + +_return_primitive = lambda value: value + + +def _serialize_list(list_: list) -> list: + return [serialize(item) for item in list_] + + +def _serialize_dict(dict_: dict) -> dict: + value = {k: serialize(v) for k,v in dict_.items()} + return encode( + value=value, + type_=DAT.DICT + ) + + +def _serialize_set(set_: set) -> dict: + value = [serialize(item) for item in set_] + return encode( + value=value, + type_=DAT.SET + ) + + +def _serialize_tuple(tupe_: tuple) -> dict: + value = [serialize(item) for item in tuple_] + return encode( + value=value, + type_=DAT.TUPLE + ) + + +def serialize(value: Any) -> Any: + "Serialize a value so it can be stored as JSON." + serialization_function_by_type = { + int: _return_primitive, + bool: _return_primitive, + float: _return_primitive, + str: _return_primitive, + list: _serialize_list, + dict: _serialize_dict, + set: _serialize_set, + tuple: _serialize_tuple + } + try: + function = serialization_function_by_type[type(value)] + except KeyError: + raise TypeError(f"Unable to serialize {type(value)}") + return function(value) + + +def _deserialize_list(list_: list) -> list: + return [deserialize(item) for item in list_] + + +def _deserialize_dict(dict_: dict) -> dict: + type_ = dict_[Encoding.TYPE] + value = dict_[Encoding.VAR] + if type_ == DAT.DICT: + value = {k: deserialize(v) for k,v in value.items()} + elif type_ == DAT.SET: + value = _deserialize_set(value) + elif type_ == DAT.TUPLE: + value = _deserialize_tuple() + else: + raise TypeError(f"Unable to deserialize dict of {Encoding.TYPE}: {type_}") + return value + + +def _deserialize_set(set_: set) -> set: + return set([deserialize(item) for item in set_]) + + +def _deserialize_tuple(tuple_: tuple) -> tuple: + return tuple(deserialize(item) for item in tuple_) + + +def deserialize(value: Any) -> Any: + "Deserialize a JSON-compatible value into its original value." + deserialization_function_by_type = { + int: _return_primitive, + bool: _return_primitive, + float: _return_primitive, + str: _return_primitive, + list: _deserialize_list, + dict: _deserialize_dict + } + try: + function = deserialization_function_by_type[type(value)] + except KeyError: + raise ValueError(f"Unable to deserialize {type(value)}") + return function(value) diff --git a/tests/serialization/test_json.py b/tests/serialization/test_json.py new file mode 100644 index 0000000000000..c7d926dcfed97 --- /dev/null +++ b/tests/serialization/test_json.py @@ -0,0 +1,183 @@ +import enum +import unittest + +import pytest + +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding +from airflow.serialization.json import deserialize, serialize + + +class TestSerialize(unittest.TestCase): + + def test_integer(self): + assert serialize(1) == 1 + assert serialize(300) == 300 + + def test_boolean(self): + assert serialize(True) == True + assert serialize(False) == False + + def test_string(self): + assert serialize("string") == "string" + + def test_float(self): + assert serialize(0.1) == 0.1 + + def test_list_of_primitives(self): + assert serialize([1,"a",True]) == [1,"a", True] + + def test_list_of_sets(self): + values = [ + {1}, + {2} + ] + expected_serialization = [ + { + Encoding.VAR: [1], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: [2], + Encoding.TYPE: DAT.SET + }, + ] + assert serialize(values) == expected_serialization + + def test_dictionary_of_primitives(self): + value = {"key": "value"} + expected_serialization = { + Encoding.VAR: {"key": "value"}, + Encoding.TYPE: DAT.DICT + } + assert serialize(value) == expected_serialization + + def test_dictionary_of_list_of_sets(self): + value = { + "key_to_list": [ + set(["a"]), + set(["b"]) + ] + } + expected_serialization = { + Encoding.VAR: { + "key_to_list": [ + { + Encoding.VAR: ["a"], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: ["b"], + Encoding.TYPE: DAT.SET + } + ] + }, + Encoding.TYPE: DAT.DICT + } + assert serialize(value) == expected_serialization + + def test_set(self): + value = {1, 2, 3} + expected_serialization = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.SET + } + assert serialize(value) == expected_serialization + + def test_enum_raises_exception(self): + value = Encoding.VAR + with pytest.raises(TypeError) as err: + serialize(value) + assert err.value.args[0] == "Unable to serialize " + + +class TestDeserialize(unittest.TestCase): + + def test_integer(self): + assert deserialize(1) == 1 + assert deserialize(300) == 300 + + def test_boolean(self): + assert deserialize(True) == True + assert deserialize(False) == False + + def test_string(self): + assert deserialize("string") == "string" + + def test_float(self): + assert deserialize(0.1) == 0.1 + + def test_list_of_primitives(self): + assert deserialize([1,"a",True]) == [1,"a", True] + + def test_list_of_sets(self): + values = [ + { + Encoding.VAR: [1], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: [2], + Encoding.TYPE: DAT.SET + } + ] + expected_deserialization = [ + {1}, + {2} + ] + assert deserialize(values) == expected_deserialization + + def test_dictionary_of_primitives(self): + value = { + Encoding.VAR: {"key": "value"}, + Encoding.TYPE: DAT.DICT + } + expected_deserialization = {"key": "value"} + + assert deserialize(value) == expected_deserialization + + def test_dictionary_of_list_of_sets(self): + value = { + Encoding.VAR: { + "key_to_list": [ + { + Encoding.VAR: ["a"], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: ["b"], + Encoding.TYPE: DAT.SET + } + ] + }, + Encoding.TYPE: DAT.DICT + } + expected_deserialization = { + "key_to_list": [ + set(["a"]), + set(["b"]) + ] + } + assert deserialize(value) == expected_deserialization + + def test_set(self): + value = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.SET + } + expected_deserialization = {1, 2, 3} + assert deserialize(value) == expected_deserialization + + def test_unsupported_value(self): + value = Encoding.VAR + with pytest.raises(ValueError) as err: + deserialize(value) + assert err.value.args[0] == "Unable to deserialize " + + def test_unsupported_dict_type(self): + value = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.DAG + } + with pytest.raises(TypeError) as err: + deserialize(value) + assert err.value.args[0] == "Unable to deserialize dict of __type: dag" From f6dd276100dfe7f2a767c65669cf7bb577441e80 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 16 Jul 2020 00:17:13 +0100 Subject: [PATCH 3/4] Add support to using sets in XCom when using JSON serialization. Resolves: #8703 --- airflow/models/xcom.py | 7 ++- airflow/serialization/json.py | 78 ++++++++++++++++++++------------ tests/models/test_xcom.py | 30 +++++++++++- tests/serialization/test_json.py | 55 +++++++++++++++++----- 4 files changed, 127 insertions(+), 43 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 05cd8d610b52a..5982d70fbd4a6 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -28,6 +28,7 @@ from airflow.configuration import conf from airflow.models.base import COLLATION_ARGS, ID_LEN, Base +from airflow.serialization.json import deserialize, serialize from airflow.utils import timezone from airflow.utils.helpers import is_container from airflow.utils.log.logging_mixin import LoggingMixin @@ -251,7 +252,8 @@ def serialize_value(value: Any): if conf.getboolean('core', 'enable_xcom_pickling'): return pickle.dumps(value) try: - return json.dumps(value).encode('UTF-8') + dict_ = serialize(value) + return json.dumps(dict_).encode('UTF-8') except (ValueError, TypeError): log.error("Could not serialize the XCOM value into JSON. " "If you are using pickles instead of JSON " @@ -269,7 +271,8 @@ def deserialize_value(result) -> Any: return pickle.loads(result.value) try: - return json.loads(result.value.decode('UTF-8')) + dict_ = json.loads(result.value.decode('UTF-8')) + return deserialize(dict_) except JSONDecodeError: log.error("Could not deserialize the XCOM value from JSON. " "If you are using pickles instead of JSON " diff --git a/airflow/serialization/json.py b/airflow/serialization/json.py index e6fa932bf4d5c..a73c1297dfd69 100644 --- a/airflow/serialization/json.py +++ b/airflow/serialization/json.py @@ -1,12 +1,28 @@ -import enum -import json -from typing import Any, Union +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding -def encode(value: Any, type_: Any) -> [Encoding, Any]: - "Encode value and type into a JSON dict." +def encode(value: Any, type_: Any) -> dict: + """Encode value and type into a JSON-friendly dict.""" return {Encoding.VAR: value, Encoding.TYPE: type_} @@ -18,31 +34,29 @@ def _serialize_list(list_: list) -> list: def _serialize_dict(dict_: dict) -> dict: - value = {k: serialize(v) for k,v in dict_.items()} - return encode( - value=value, - type_=DAT.DICT - ) + value = {k: serialize(v) for k, v in dict_.items()} + return encode(value=value, type_=DAT.DICT) def _serialize_set(set_: set) -> dict: value = [serialize(item) for item in set_] - return encode( - value=value, - type_=DAT.SET - ) + return encode(value=value, type_=DAT.SET) -def _serialize_tuple(tupe_: tuple) -> dict: +def _serialize_tuple(tuple_: tuple) -> dict: value = [serialize(item) for item in tuple_] - return encode( - value=value, - type_=DAT.TUPLE - ) + return encode(value=value, type_=DAT.TUPLE) -def serialize(value: Any) -> Any: - "Serialize a value so it can be stored as JSON." +def serialize(value): + """Serialize a value so it can be stored as JSON. + + The serialization protocol is: + + (1) keep JSON supported types: primitives, dict, list; + (2) encode other types such as tuples and sets as + ``{TYPE: 'foo', VAR: 'bar'}`` + """ serialization_function_by_type = { int: _return_primitive, bool: _return_primitive, @@ -53,7 +67,7 @@ def serialize(value: Any) -> Any: set: _serialize_set, tuple: _serialize_tuple } - try: + try: function = serialization_function_by_type[type(value)] except KeyError: raise TypeError(f"Unable to serialize {type(value)}") @@ -68,26 +82,34 @@ def _deserialize_dict(dict_: dict) -> dict: type_ = dict_[Encoding.TYPE] value = dict_[Encoding.VAR] if type_ == DAT.DICT: - value = {k: deserialize(v) for k,v in value.items()} + value = {k: deserialize(v) for k, v in value.items()} elif type_ == DAT.SET: value = _deserialize_set(value) elif type_ == DAT.TUPLE: - value = _deserialize_tuple() + value = _deserialize_tuple(value) else: raise TypeError(f"Unable to deserialize dict of {Encoding.TYPE}: {type_}") return value def _deserialize_set(set_: set) -> set: - return set([deserialize(item) for item in set_]) + return {deserialize(item) for item in set_} def _deserialize_tuple(tuple_: tuple) -> tuple: return tuple(deserialize(item) for item in tuple_) -def deserialize(value: Any) -> Any: - "Deserialize a JSON-compatible value into its original value." +def deserialize(value): + """Deserialize a JSON-compatible value into its original value. + + The deserialization protocol is: + + (1) keep JSON supported types: primitives, lists + (2) decode other types which might have been encoded in dicts using + ``{TYPE: 'foo', VAR: 'bar'}`` + """ + deserialization_function_by_type = { int: _return_primitive, bool: _return_primitive, @@ -96,7 +118,7 @@ def deserialize(value: Any) -> Any: list: _deserialize_list, dict: _deserialize_dict } - try: + try: function = deserialization_function_by_type[type(value)] except KeyError: raise ValueError(f"Unable to deserialize {type(value)}") diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 887ef477d8ec6..3211c8f2b905f 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -214,8 +214,8 @@ def test_xcom_get_many(self): self.assertEqual(result.value, json_obj) @conf_vars({("core", "xcom_enable_pickling"): "False"}) - def test_xcom_enable_set_type(self): - set_obj = {"set-value1", "set-value2"} + def test_xcom_json_serialization_supports_set(self): + set_obj = set(["set-value1", "set-value2"]) execution_date = timezone.utcnow() key = "xcom_test5" dag_id = "test_dag5" @@ -238,5 +238,31 @@ def test_xcom_enable_set_type(self): XCom.task_id == task_id, XCom.execution_date == execution_date ).first().value + self.assertEqual(ret_value, set_obj) + @conf_vars({("core", "xcom_enable_pickling"): "False"}) + def test_xcom_json_serialization_supports_nested_sets(self): + set_obj = {"dict_key": set(["set-value"])} + execution_date = timezone.utcnow() + key = "xcom_test6" + dag_id = "test_dag6" + task_id = "test_task6" + XCom.set(key=key, + value=set_obj, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + ret_value = XCom.get_one(key=key, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + self.assertEqual(ret_value, set_obj) + + session = settings.Session() + ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date + ).first().value self.assertEqual(ret_value, set_obj) diff --git a/tests/serialization/test_json.py b/tests/serialization/test_json.py index c7d926dcfed97..5241a964c0ff2 100644 --- a/tests/serialization/test_json.py +++ b/tests/serialization/test_json.py @@ -1,4 +1,21 @@ -import enum +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import unittest import pytest @@ -14,8 +31,8 @@ def test_integer(self): assert serialize(300) == 300 def test_boolean(self): - assert serialize(True) == True - assert serialize(False) == False + assert serialize(True) + assert not serialize(False) def test_string(self): assert serialize("string") == "string" @@ -24,7 +41,7 @@ def test_float(self): assert serialize(0.1) == 0.1 def test_list_of_primitives(self): - assert serialize([1,"a",True]) == [1,"a", True] + assert serialize([1, "a", True]) == [1, "a", True] def test_list_of_sets(self): values = [ @@ -83,6 +100,14 @@ def test_set(self): } assert serialize(value) == expected_serialization + def test_tuple(self): + value = (1, 2, 3) + expected_serialization = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.TUPLE + } + assert serialize(value) == expected_serialization + def test_enum_raises_exception(self): value = Encoding.VAR with pytest.raises(TypeError) as err: @@ -97,8 +122,8 @@ def test_integer(self): assert deserialize(300) == 300 def test_boolean(self): - assert deserialize(True) == True - assert deserialize(False) == False + assert deserialize(True) + assert not deserialize(False) def test_string(self): assert deserialize("string") == "string" @@ -107,7 +132,7 @@ def test_float(self): assert deserialize(0.1) == 0.1 def test_list_of_primitives(self): - assert deserialize([1,"a",True]) == [1,"a", True] + assert deserialize([1, "a", True]) == [1, "a", True] def test_list_of_sets(self): values = [ @@ -132,7 +157,7 @@ def test_dictionary_of_primitives(self): Encoding.TYPE: DAT.DICT } expected_deserialization = {"key": "value"} - + assert deserialize(value) == expected_deserialization def test_dictionary_of_list_of_sets(self): @@ -164,14 +189,22 @@ def test_set(self): Encoding.VAR: [1, 2, 3], Encoding.TYPE: DAT.SET } - expected_deserialization = {1, 2, 3} + expected_deserialization = {1, 2, 3} + assert deserialize(value) == expected_deserialization + + def test_tuple(self): + value = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.TUPLE + } + expected_deserialization = (1, 2, 3) assert deserialize(value) == expected_deserialization def test_unsupported_value(self): value = Encoding.VAR with pytest.raises(ValueError) as err: deserialize(value) - assert err.value.args[0] == "Unable to deserialize " + assert err.value.args[0] == "Unable to deserialize " def test_unsupported_dict_type(self): value = { @@ -180,4 +213,4 @@ def test_unsupported_dict_type(self): } with pytest.raises(TypeError) as err: deserialize(value) - assert err.value.args[0] == "Unable to deserialize dict of __type: dag" + assert err.value.args[0] == "Unable to deserialize dict of __type: dag" From 4be9c0f341b05762a7677fb53087aba1b03c6d3a Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 21 Jul 2020 01:10:15 +0100 Subject: [PATCH 4/4] Change serialization function to support NoneType --- airflow/serialization/json.py | 2 ++ tests/serialization/test_json.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/airflow/serialization/json.py b/airflow/serialization/json.py index a73c1297dfd69..f05c9af64d3b1 100644 --- a/airflow/serialization/json.py +++ b/airflow/serialization/json.py @@ -58,6 +58,7 @@ def serialize(value): ``{TYPE: 'foo', VAR: 'bar'}`` """ serialization_function_by_type = { + type(None): _return_primitive, int: _return_primitive, bool: _return_primitive, float: _return_primitive, @@ -111,6 +112,7 @@ def deserialize(value): """ deserialization_function_by_type = { + type(None): _return_primitive, int: _return_primitive, bool: _return_primitive, float: _return_primitive, diff --git a/tests/serialization/test_json.py b/tests/serialization/test_json.py index 5241a964c0ff2..4dd7fc0110775 100644 --- a/tests/serialization/test_json.py +++ b/tests/serialization/test_json.py @@ -26,6 +26,9 @@ class TestSerialize(unittest.TestCase): + def test_none(self): + assert serialize(None) is None + def test_integer(self): assert serialize(1) == 1 assert serialize(300) == 300 @@ -117,6 +120,9 @@ def test_enum_raises_exception(self): class TestDeserialize(unittest.TestCase): + def test_none(self): + assert deserialize(None) is None + def test_integer(self): assert deserialize(1) == 1 assert deserialize(300) == 300