diff --git a/airflow/serialization/serializers/timezone.py b/airflow/serialization/serializers/timezone.py index b55b51610b41b..5d3b940cd78f4 100644 --- a/airflow/serialization/serializers/timezone.py +++ b/airflow/serialization/serializers/timezone.py @@ -17,17 +17,22 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +import datetime +from typing import TYPE_CHECKING, Any, cast from airflow.utils.module_loading import qualname if TYPE_CHECKING: - from pendulum.tz.timezone import Timezone - from airflow.serialization.serde import U -serializers = ["pendulum.tz.timezone.FixedTimezone", "pendulum.tz.timezone.Timezone"] +serializers = [ + "pendulum.tz.timezone.FixedTimezone", + "pendulum.tz.timezone.Timezone", + "zoneinfo.ZoneInfo", + "backports.zoneinfo.ZoneInfo", +] + deserializers = serializers __version__ = 1 @@ -43,21 +48,26 @@ def serialize(o: object) -> tuple[U, str, int, bool]: 0 without the special case), but passing 0 into ``pendulum.timezone`` does not give us UTC (but ``+00:00``). """ - from pendulum.tz.timezone import FixedTimezone, Timezone + from pendulum.tz.timezone import FixedTimezone name = qualname(o) + if isinstance(o, FixedTimezone): if o.offset == 0: return "UTC", name, __version__, True return o.offset, name, __version__, True - if isinstance(o, Timezone): - return o.name, name, __version__, True + tz_name = _get_tzinfo_name(cast(datetime.tzinfo, o)) + if tz_name is not None: + return tz_name, name, __version__, True + + if cast(datetime.tzinfo, o).utcoffset(None) == datetime.timedelta(0): + return "UTC", qualname(FixedTimezone), __version__, True return "", "", 0, False -def deserialize(classname: str, version: int, data: object) -> Timezone: +def deserialize(classname: str, version: int, data: object) -> Any: from pendulum.tz import fixed_timezone, timezone if not isinstance(data, (str, int)): @@ -69,4 +79,36 @@ def deserialize(classname: str, version: int, data: object) -> Timezone: if isinstance(data, int): return fixed_timezone(data) + if classname == "zoneinfo.ZoneInfo": + from zoneinfo import ZoneInfo + + return ZoneInfo(data) + + if classname == "backports.zoneinfo.ZoneInfo": + # python version might have been upgraded, so we need to check + try: + from backports.zoneinfo import ZoneInfo + except ImportError: + from zoneinfo import ZoneInfo + + return ZoneInfo(data) + return timezone(data) + + +# ported from pendulum.tz.timezone._get_tzinfo_name +def _get_tzinfo_name(tzinfo: datetime.tzinfo | None) -> str | None: + if tzinfo is None: + return None + + if hasattr(tzinfo, "key"): + # zoneinfo timezone + return tzinfo.key + elif hasattr(tzinfo, "name"): + # Pendulum timezone + return tzinfo.name + elif hasattr(tzinfo, "zone"): + # pytz timezone + return tzinfo.zone # type: ignore[no-any-return] + + return None diff --git a/setup.py b/setup.py index d5b8c333d0c93..c69fffc7ae89f 100644 --- a/setup.py +++ b/setup.py @@ -462,6 +462,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve _devel_only_tests = [ "aioresponses", + "backports.zoneinfo>=0.2.1;python_version<'3.9'", "beautifulsoup4>=4.7.1", "coverage>=7.2", "pytest", diff --git a/tests/serialization/serializers/test_serializers.py b/tests/serialization/serializers/test_serializers.py index 26e4ecea0eac1..2c9c94e5c81cb 100644 --- a/tests/serialization/serializers/test_serializers.py +++ b/tests/serialization/serializers/test_serializers.py @@ -22,11 +22,18 @@ import numpy as np import pendulum.tz import pytest +from dateutil.tz import tzutc from pendulum import DateTime +from airflow import PY39 from airflow.models.param import Param, ParamsDict from airflow.serialization.serde import DATA, deserialize, serialize +if PY39: + from zoneinfo import ZoneInfo +else: + from backports.zoneinfo import ZoneInfo + class TestSerializers: def test_datetime(self): @@ -62,8 +69,17 @@ def test_datetime(self): d = deserialize(s) assert i.timestamp() == d.timestamp() - def test_deserialize_datetime_v1(self): + i = DateTime(2022, 7, 10, tzinfo=tzutc()) + s = serialize(i) + d = deserialize(s) + assert i.timestamp() == d.timestamp() + i = DateTime(2022, 7, 10, tzinfo=ZoneInfo("Europe/Paris")) + s = serialize(i) + d = deserialize(s) + assert i.timestamp() == d.timestamp() + + def test_deserialize_datetime_v1(self): s = { "__classname__": "pendulum.datetime.DateTime", "__version__": 1,