diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 05cd8d610b52a..5982d70fbd4a6 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -28,6 +28,7 @@ from airflow.configuration import conf from airflow.models.base import COLLATION_ARGS, ID_LEN, Base +from airflow.serialization.json import deserialize, serialize from airflow.utils import timezone from airflow.utils.helpers import is_container from airflow.utils.log.logging_mixin import LoggingMixin @@ -251,7 +252,8 @@ def serialize_value(value: Any): if conf.getboolean('core', 'enable_xcom_pickling'): return pickle.dumps(value) try: - return json.dumps(value).encode('UTF-8') + dict_ = serialize(value) + return json.dumps(dict_).encode('UTF-8') except (ValueError, TypeError): log.error("Could not serialize the XCOM value into JSON. " "If you are using pickles instead of JSON " @@ -269,7 +271,8 @@ def deserialize_value(result) -> Any: return pickle.loads(result.value) try: - return json.loads(result.value.decode('UTF-8')) + dict_ = json.loads(result.value.decode('UTF-8')) + return deserialize(dict_) except JSONDecodeError: log.error("Could not deserialize the XCOM value from JSON. " "If you are using pickles instead of JSON " diff --git a/airflow/serialization/json.py b/airflow/serialization/json.py new file mode 100644 index 0000000000000..f05c9af64d3b1 --- /dev/null +++ b/airflow/serialization/json.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any + +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding + + +def encode(value: Any, type_: Any) -> dict: + """Encode value and type into a JSON-friendly dict.""" + return {Encoding.VAR: value, Encoding.TYPE: type_} + + +_return_primitive = lambda value: value + + +def _serialize_list(list_: list) -> list: + return [serialize(item) for item in list_] + + +def _serialize_dict(dict_: dict) -> dict: + value = {k: serialize(v) for k, v in dict_.items()} + return encode(value=value, type_=DAT.DICT) + + +def _serialize_set(set_: set) -> dict: + value = [serialize(item) for item in set_] + return encode(value=value, type_=DAT.SET) + + +def _serialize_tuple(tuple_: tuple) -> dict: + value = [serialize(item) for item in tuple_] + return encode(value=value, type_=DAT.TUPLE) + + +def serialize(value): + """Serialize a value so it can be stored as JSON. + + The serialization protocol is: + + (1) keep JSON supported types: primitives, dict, list; + (2) encode other types such as tuples and sets as + ``{TYPE: 'foo', VAR: 'bar'}`` + """ + serialization_function_by_type = { + type(None): _return_primitive, + int: _return_primitive, + bool: _return_primitive, + float: _return_primitive, + str: _return_primitive, + list: _serialize_list, + dict: _serialize_dict, + set: _serialize_set, + tuple: _serialize_tuple + } + try: + function = serialization_function_by_type[type(value)] + except KeyError: + raise TypeError(f"Unable to serialize {type(value)}") + return function(value) + + +def _deserialize_list(list_: list) -> list: + return [deserialize(item) for item in list_] + + +def _deserialize_dict(dict_: dict) -> dict: + type_ = dict_[Encoding.TYPE] + value = dict_[Encoding.VAR] + if type_ == DAT.DICT: + value = {k: deserialize(v) for k, v in value.items()} + elif type_ == DAT.SET: + value = _deserialize_set(value) + elif type_ == DAT.TUPLE: + value = _deserialize_tuple(value) + else: + raise TypeError(f"Unable to deserialize dict of {Encoding.TYPE}: {type_}") + return value + + +def _deserialize_set(set_: set) -> set: + return {deserialize(item) for item in set_} + + +def _deserialize_tuple(tuple_: tuple) -> tuple: + return tuple(deserialize(item) for item in tuple_) + + +def deserialize(value): + """Deserialize a JSON-compatible value into its original value. + + The deserialization protocol is: + + (1) keep JSON supported types: primitives, lists + (2) decode other types which might have been encoded in dicts using + ``{TYPE: 'foo', VAR: 'bar'}`` + """ + + deserialization_function_by_type = { + type(None): _return_primitive, + int: _return_primitive, + bool: _return_primitive, + float: _return_primitive, + str: _return_primitive, + list: _deserialize_list, + dict: _deserialize_dict + } + try: + function = deserialization_function_by_type[type(value)] + except KeyError: + raise ValueError(f"Unable to deserialize {type(value)}") + return function(value) diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py index 39586d1638468..3211c8f2b905f 100644 --- a/tests/models/test_xcom.py +++ b/tests/models/test_xcom.py @@ -212,3 +212,57 @@ def test_xcom_get_many(self): for result in results: self.assertEqual(result.value, json_obj) + + @conf_vars({("core", "xcom_enable_pickling"): "False"}) + def test_xcom_json_serialization_supports_set(self): + set_obj = set(["set-value1", "set-value2"]) + execution_date = timezone.utcnow() + key = "xcom_test5" + dag_id = "test_dag5" + task_id = "test_task5" + XCom.set(key=key, + value=set_obj, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + ret_value = XCom.get_one(key=key, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + self.assertEqual(ret_value, set_obj) + + session = settings.Session() + ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date + ).first().value + self.assertEqual(ret_value, set_obj) + + @conf_vars({("core", "xcom_enable_pickling"): "False"}) + def test_xcom_json_serialization_supports_nested_sets(self): + set_obj = {"dict_key": set(["set-value"])} + execution_date = timezone.utcnow() + key = "xcom_test6" + dag_id = "test_dag6" + task_id = "test_task6" + XCom.set(key=key, + value=set_obj, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + ret_value = XCom.get_one(key=key, + dag_id=dag_id, + task_id=task_id, + execution_date=execution_date) + + self.assertEqual(ret_value, set_obj) + + session = settings.Session() + ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == execution_date + ).first().value + self.assertEqual(ret_value, set_obj) diff --git a/tests/serialization/test_json.py b/tests/serialization/test_json.py new file mode 100644 index 0000000000000..4dd7fc0110775 --- /dev/null +++ b/tests/serialization/test_json.py @@ -0,0 +1,222 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +import pytest + +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding +from airflow.serialization.json import deserialize, serialize + + +class TestSerialize(unittest.TestCase): + + def test_none(self): + assert serialize(None) is None + + def test_integer(self): + assert serialize(1) == 1 + assert serialize(300) == 300 + + def test_boolean(self): + assert serialize(True) + assert not serialize(False) + + def test_string(self): + assert serialize("string") == "string" + + def test_float(self): + assert serialize(0.1) == 0.1 + + def test_list_of_primitives(self): + assert serialize([1, "a", True]) == [1, "a", True] + + def test_list_of_sets(self): + values = [ + {1}, + {2} + ] + expected_serialization = [ + { + Encoding.VAR: [1], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: [2], + Encoding.TYPE: DAT.SET + }, + ] + assert serialize(values) == expected_serialization + + def test_dictionary_of_primitives(self): + value = {"key": "value"} + expected_serialization = { + Encoding.VAR: {"key": "value"}, + Encoding.TYPE: DAT.DICT + } + assert serialize(value) == expected_serialization + + def test_dictionary_of_list_of_sets(self): + value = { + "key_to_list": [ + set(["a"]), + set(["b"]) + ] + } + expected_serialization = { + Encoding.VAR: { + "key_to_list": [ + { + Encoding.VAR: ["a"], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: ["b"], + Encoding.TYPE: DAT.SET + } + ] + }, + Encoding.TYPE: DAT.DICT + } + assert serialize(value) == expected_serialization + + def test_set(self): + value = {1, 2, 3} + expected_serialization = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.SET + } + assert serialize(value) == expected_serialization + + def test_tuple(self): + value = (1, 2, 3) + expected_serialization = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.TUPLE + } + assert serialize(value) == expected_serialization + + def test_enum_raises_exception(self): + value = Encoding.VAR + with pytest.raises(TypeError) as err: + serialize(value) + assert err.value.args[0] == "Unable to serialize " + + +class TestDeserialize(unittest.TestCase): + + def test_none(self): + assert deserialize(None) is None + + def test_integer(self): + assert deserialize(1) == 1 + assert deserialize(300) == 300 + + def test_boolean(self): + assert deserialize(True) + assert not deserialize(False) + + def test_string(self): + assert deserialize("string") == "string" + + def test_float(self): + assert deserialize(0.1) == 0.1 + + def test_list_of_primitives(self): + assert deserialize([1, "a", True]) == [1, "a", True] + + def test_list_of_sets(self): + values = [ + { + Encoding.VAR: [1], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: [2], + Encoding.TYPE: DAT.SET + } + ] + expected_deserialization = [ + {1}, + {2} + ] + assert deserialize(values) == expected_deserialization + + def test_dictionary_of_primitives(self): + value = { + Encoding.VAR: {"key": "value"}, + Encoding.TYPE: DAT.DICT + } + expected_deserialization = {"key": "value"} + + assert deserialize(value) == expected_deserialization + + def test_dictionary_of_list_of_sets(self): + value = { + Encoding.VAR: { + "key_to_list": [ + { + Encoding.VAR: ["a"], + Encoding.TYPE: DAT.SET + }, + { + Encoding.VAR: ["b"], + Encoding.TYPE: DAT.SET + } + ] + }, + Encoding.TYPE: DAT.DICT + } + expected_deserialization = { + "key_to_list": [ + set(["a"]), + set(["b"]) + ] + } + assert deserialize(value) == expected_deserialization + + def test_set(self): + value = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.SET + } + expected_deserialization = {1, 2, 3} + assert deserialize(value) == expected_deserialization + + def test_tuple(self): + value = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.TUPLE + } + expected_deserialization = (1, 2, 3) + assert deserialize(value) == expected_deserialization + + def test_unsupported_value(self): + value = Encoding.VAR + with pytest.raises(ValueError) as err: + deserialize(value) + assert err.value.args[0] == "Unable to deserialize " + + def test_unsupported_dict_type(self): + value = { + Encoding.VAR: [1, 2, 3], + Encoding.TYPE: DAT.DAG + } + with pytest.raises(TypeError) as err: + deserialize(value) + assert err.value.args[0] == "Unable to deserialize dict of __type: dag"