Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from airflow.configuration import conf
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.serialization.json import deserialize, serialize
from airflow.utils import timezone
from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -251,7 +252,8 @@ def serialize_value(value: Any):
if conf.getboolean('core', 'enable_xcom_pickling'):
return pickle.dumps(value)
try:
return json.dumps(value).encode('UTF-8')
dict_ = serialize(value)
return json.dumps(dict_).encode('UTF-8')
except (ValueError, TypeError):
log.error("Could not serialize the XCOM value into JSON. "
"If you are using pickles instead of JSON "
Expand All @@ -269,7 +271,8 @@ def deserialize_value(result) -> Any:
return pickle.loads(result.value)

try:
return json.loads(result.value.decode('UTF-8'))
dict_ = json.loads(result.value.decode('UTF-8'))
return deserialize(dict_)
except JSONDecodeError:
log.error("Could not deserialize the XCOM value from JSON. "
"If you are using pickles instead of JSON "
Expand Down
127 changes: 127 additions & 0 deletions airflow/serialization/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any

from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding


def encode(value: Any, type_: Any) -> dict:
"""Encode value and type into a JSON-friendly dict."""
return {Encoding.VAR: value, Encoding.TYPE: type_}


_return_primitive = lambda value: value
Comment thread
turbaszek marked this conversation as resolved.
Outdated


def _serialize_list(list_: list) -> list:
return [serialize(item) for item in list_]


def _serialize_dict(dict_: dict) -> dict:
value = {k: serialize(v) for k, v in dict_.items()}
return encode(value=value, type_=DAT.DICT)


def _serialize_set(set_: set) -> dict:
value = [serialize(item) for item in set_]
return encode(value=value, type_=DAT.SET)


def _serialize_tuple(tuple_: tuple) -> dict:
value = [serialize(item) for item in tuple_]
return encode(value=value, type_=DAT.TUPLE)


def serialize(value):
"""Serialize a value so it can be stored as JSON.

The serialization protocol is:

(1) keep JSON supported types: primitives, dict, list;
(2) encode other types such as tuples and sets as
``{TYPE: 'foo', VAR: 'bar'}``
"""
serialization_function_by_type = {
type(None): _return_primitive,
int: _return_primitive,
bool: _return_primitive,
float: _return_primitive,
str: _return_primitive,
list: _serialize_list,
dict: _serialize_dict,
set: _serialize_set,
tuple: _serialize_tuple
}
try:
function = serialization_function_by_type[type(value)]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should support None as possible return value for XCom

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @turbaszek, for pointing this out - I just added a test & support for None.

except KeyError:
raise TypeError(f"Unable to serialize {type(value)}")
return function(value)


def _deserialize_list(list_: list) -> list:
return [deserialize(item) for item in list_]


def _deserialize_dict(dict_: dict) -> dict:
type_ = dict_[Encoding.TYPE]
value = dict_[Encoding.VAR]
if type_ == DAT.DICT:
value = {k: deserialize(v) for k, v in value.items()}
elif type_ == DAT.SET:
value = _deserialize_set(value)
elif type_ == DAT.TUPLE:
value = _deserialize_tuple(value)
else:
raise TypeError(f"Unable to deserialize dict of {Encoding.TYPE}: {type_}")
return value


def _deserialize_set(set_: set) -> set:
return {deserialize(item) for item in set_}


def _deserialize_tuple(tuple_: tuple) -> tuple:
return tuple(deserialize(item) for item in tuple_)


def deserialize(value):
"""Deserialize a JSON-compatible value into its original value.

The deserialization protocol is:

(1) keep JSON supported types: primitives, lists
(2) decode other types which might have been encoded in dicts using
``{TYPE: 'foo', VAR: 'bar'}``
"""

deserialization_function_by_type = {
type(None): _return_primitive,
int: _return_primitive,
bool: _return_primitive,
float: _return_primitive,
str: _return_primitive,
list: _deserialize_list,
dict: _deserialize_dict
}
try:
function = deserialization_function_by_type[type(value)]
except KeyError:
raise ValueError(f"Unable to deserialize {type(value)}")
return function(value)
54 changes: 54 additions & 0 deletions tests/models/test_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,57 @@ def test_xcom_get_many(self):

for result in results:
self.assertEqual(result.value, json_obj)

@conf_vars({("core", "xcom_enable_pickling"): "False"})
def test_xcom_json_serialization_supports_set(self):
set_obj = set(["set-value1", "set-value2"])
execution_date = timezone.utcnow()
key = "xcom_test5"
dag_id = "test_dag5"
task_id = "test_task5"
XCom.set(key=key,
value=set_obj,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)

ret_value = XCom.get_one(key=key,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)

self.assertEqual(ret_value, set_obj)

session = settings.Session()
ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
XCom.task_id == task_id,
XCom.execution_date == execution_date
).first().value
self.assertEqual(ret_value, set_obj)

@conf_vars({("core", "xcom_enable_pickling"): "False"})
def test_xcom_json_serialization_supports_nested_sets(self):
set_obj = {"dict_key": set(["set-value"])}
execution_date = timezone.utcnow()
key = "xcom_test6"
dag_id = "test_dag6"
task_id = "test_task6"
XCom.set(key=key,
value=set_obj,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)

ret_value = XCom.get_one(key=key,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)

self.assertEqual(ret_value, set_obj)

session = settings.Session()
ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
XCom.task_id == task_id,
XCom.execution_date == execution_date
).first().value
self.assertEqual(ret_value, set_obj)
Loading