From 0463b3fad5850ea162f193ad25aa8c2d8327896d Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Sun, 26 Feb 2023 15:38:46 +0100 Subject: [PATCH] Add Pydantic-powered ORM models serialization for internal API. Add basic serialization capabilities for the ORM SqlAlchemy models that we use on the client side of the Internal API. Serializing the whole ORM models is rather complex, therefore it seems much more reasonable to convert the ORM models into serializable form and use them - rather than the SQLAlchemy models. There are just a handful of those models that we need to serialize, and it is important to maintain typing of the fields in the objects for MyPy verification so we can allow some level of duplication and redefine the models as pure Python objects. We only need one-way converstion (from database models to Python models), because all the DB operations and modifications of the Database entries will be done in the internal API server, so the server side of any method will be able to use primary key stored in the serializable object, to retrieve the actual DB model to update. We also need to serialization to work both way - an easy way to convert such Python classees to json and back - including validation. We could serialize those models manually, but this would be quite an overhead to develop and maintain - therefore we are harnessing the power of Pydantic, that has already ORM mapping to plain Python (Pydantic) classes built in. This PR implements definition of the Pydantic classes and tests for the classes testing: * conversion of the ORM models to Pydantic objects * serialization of the Pydantic classes to json * deserialization of the json-serialized classes to Pydantic objects --- airflow/jobs/pydantic/__init__.py | 16 +++ airflow/jobs/pydantic/base_job.py | 44 +++++++ airflow/models/pydantic/__init__.py | 16 +++ airflow/models/pydantic/dag_run.py | 50 ++++++++ airflow/models/pydantic/dataset.py | 92 +++++++++++++++ airflow/models/pydantic/taskinstance.py | 59 +++++++++ airflow/utils/context.pyi | 11 +- pyproject.toml | 7 ++ setup.cfg | 1 + tests/models/test_pydantic_models.py | 151 ++++++++++++++++++++++++ 10 files changed, 443 insertions(+), 4 deletions(-) create mode 100644 airflow/jobs/pydantic/__init__.py create mode 100644 airflow/jobs/pydantic/base_job.py create mode 100644 airflow/models/pydantic/__init__.py create mode 100644 airflow/models/pydantic/dag_run.py create mode 100644 airflow/models/pydantic/dataset.py create mode 100644 airflow/models/pydantic/taskinstance.py create mode 100644 tests/models/test_pydantic_models.py diff --git a/airflow/jobs/pydantic/__init__.py b/airflow/jobs/pydantic/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/jobs/pydantic/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/jobs/pydantic/base_job.py b/airflow/jobs/pydantic/base_job.py new file mode 100644 index 0000000000000..bad9aeca48dbc --- /dev/null +++ b/airflow/jobs/pydantic/base_job.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel as BaseModelPydantic + +from airflow.models.pydantic.taskinstance import TaskInstancePydantic + + +class BaseJobPydantic(BaseModelPydantic): + """Serializable representation of the BaseJob ORM SqlAlchemyModel used by internal API""" + + id: Optional[int] + dag_id: Optional[str] + state: Optional[str] + job_type: Optional[str] + start_date: Optional[datetime] + end_date: Optional[datetime] + latest_heartbeat: Optional[datetime] + executor_class: Optional[str] + hostname: Optional[str] + unixname: Optional[str] + task_instance: TaskInstancePydantic + + class Config: + """Make sure it deals automatically with ORM classes of SQL Alchemy""" + + orm_mode = True diff --git a/airflow/models/pydantic/__init__.py b/airflow/models/pydantic/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/models/pydantic/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/models/pydantic/dag_run.py b/airflow/models/pydantic/dag_run.py new file mode 100644 index 0000000000000..e2d44296b9c60 --- /dev/null +++ b/airflow/models/pydantic/dag_run.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel as BaseModelPydantic + +from airflow.models.pydantic.dataset import DatasetEventPydantic + + +class DagRunPydantic(BaseModelPydantic): + """Serializable representation of the DagRun ORM SqlAlchemyModel used by internal API.""" + + id: int + dag_id: str + queued_at: Optional[datetime] + execution_date: datetime + start_date: Optional[datetime] + end_date: Optional[datetime] + state: str + run_id: Optional[str] + creating_job_id: Optional[int] + external_trigger: bool + run_type: str + data_interval_start: Optional[datetime] + data_interval_end: Optional[datetime] + last_scheduling_decision: Optional[datetime] + dag_hash: Optional[str] + updated_at: datetime + consumed_dataset_events: List[DatasetEventPydantic] + + class Config: + """Make sure it deals automatically with ORM classes of SQL Alchemy""" + + orm_mode = True diff --git a/airflow/models/pydantic/dataset.py b/airflow/models/pydantic/dataset.py new file mode 100644 index 0000000000000..39c552eea9234 --- /dev/null +++ b/airflow/models/pydantic/dataset.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel as BaseModelPydantic + + +class DagScheduleDatasetReferencePydantic(BaseModelPydantic): + """ + Serializable representation of the DagScheduleDatasetReference + ORM SqlAlchemyModel used by internal API. + """ + + dataset_id: int + dag_id: str + created_at: datetime + updated_at: datetime + + class Config: + """Make sure it deals automatically with ORM classes of SQL Alchemy""" + + orm_mode = True + + +class TaskOutletDatasetReferencePydantic(BaseModelPydantic): + """ + Serializable representation of the + TaskOutletDatasetReference ORM SqlAlchemyModel used by internal API. + """ + + dataset_id: int + dag_id = str + task_id = str + created_at = datetime + updated_at = datetime + + class Config: + """Make sure it deals automatically with ORM classes of SQL Alchemy""" + + orm_mode = True + + +class DatasetPydantic(BaseModelPydantic): + """Serializable representation of the Dataset ORM SqlAlchemyModel used by internal API.""" + + id: int + uri: str + extra: Optional[dict] + created_at: datetime + updated_at: datetime + is_orphaned: bool + + consuming_dags: List[DagScheduleDatasetReferencePydantic] + producing_tasks: List[TaskOutletDatasetReferencePydantic] + + class Config: + """Make sure it deals automatically with ORM classes of SQL Alchemy""" + + orm_mode = True + + +class DatasetEventPydantic(BaseModelPydantic): + """Serializable representation of the DatasetEvent ORM SqlAlchemyModel used by internal API.""" + + id: int + source_task_id: Optional[str] + source_dag_id: Optional[str] + source_run_id: Optional[str] + extra: Optional[dict] + source_map_index: int + timestamp: datetime + dataset: DatasetPydantic + + class Config: + """Make sure it deals automatically with ORM classes of SQL Alchemy""" + + orm_mode = True diff --git a/airflow/models/pydantic/taskinstance.py b/airflow/models/pydantic/taskinstance.py new file mode 100644 index 0000000000000..f174996d4d9c1 --- /dev/null +++ b/airflow/models/pydantic/taskinstance.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel as BaseModelPydantic + + +class TaskInstancePydantic(BaseModelPydantic): + """Serializable representation of the TaskInstance ORM SqlAlchemyModel used by internal API""" + + task_id: str + dag_id: str + run_id: str + map_index: str + start_date: Optional[datetime] + end_date: Optional[datetime] + duration: Optional[float] + state: Optional[str] + _try_number: int + max_tries: int + hostname: str + unixname: str + job_id: Optional[int] + pool: str + pool_slots: int + queue: str + priority_weight: Optional[int] + operator: str + queued_dttm: Optional[str] + queued_by_job_id: Optional[int] + pid: Optional[int] + updated_at: Optional[datetime] + external_executor_id: Optional[str] + trigger_id: Optional[int] + trigger_timeout: Optional[datetime] + next_method: Optional[str] + next_kwargs: Optional[dict] + run_as_user: Optional[str] + + class Config: + """Make sure it deals automatically with ORM classes of SQL Alchemy""" + + orm_mode = True diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index 838162649ae12..374d36988b299 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -36,7 +36,10 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.dataset import DatasetEvent from airflow.models.param import ParamsDict +from airflow.models.pydantic.dag_run import DagRunPydantic +from airflow.models.pydantic.dataset import DatasetEventPydantic from airflow.models.taskinstance import TaskInstance +from airflow.models.taskinstance_pydantic import TaskInstancePydantic from airflow.typing_compat import TypedDict KNOWN_CONTEXT_KEYS: set[str] @@ -57,7 +60,7 @@ class Context(TypedDict, total=False): conf: AirflowConfigParser conn: Any dag: DAG - dag_run: DagRun + dag_run: DagRun | DagRunPydantic data_interval_end: DateTime data_interval_start: DateTime ds: str @@ -82,14 +85,14 @@ class Context(TypedDict, total=False): prev_start_date_success: DateTime | None run_id: str task: BaseOperator - task_instance: TaskInstance + task_instance: TaskInstance | TaskInstancePydantic task_instance_key_str: str test_mode: bool templates_dict: Mapping[str, Any] | None - ti: TaskInstance + ti: TaskInstance | TaskInstancePydantic tomorrow_ds: str tomorrow_ds_nodash: str - triggering_dataset_events: Mapping[str, Collection[DatasetEvent]] + triggering_dataset_events: Mapping[str, Collection[DatasetEvent | DatasetEventPydantic]] ts: str ts_nodash: str ts_nodash_with_tz: str diff --git a/pyproject.toml b/pyproject.toml index 722d50de5394f..7baf4d4c6f003 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,13 @@ known-third-party = [ # needed for the test to work "tests/decorators/test_python.py" = ["I002"] +# The Pydantic representations of SqlAlchemy Models are not parsed well with Pydantic +# when __future__.annotations is used so we need to skip them from upgrading +"airflow/models/pydantic/taskinstance.py" = ["I002"] +"airflow/models/pydantic/dag_run.py" = ["I002"] +"airflow/models/pydantic/dataset.py" = ["I002"] +"airflow/jobs/pydantic/base_job.py" = ["I002"] + # Ignore pydoc style from these "*.pyi" = ["D"] "tests/*" = ["D"] diff --git a/setup.cfg b/setup.cfg index d0f8b2489f9b6..95c6b457e6428 100644 --- a/setup.cfg +++ b/setup.cfg @@ -122,6 +122,7 @@ install_requires = pendulum>=2.0 pluggy>=1.0 psutil>=4.2.0 + pydantic>=1.10.0 pygments>=2.0.1 pyjwt>=2.0.0 python-daemon>=3.0.0 diff --git a/tests/models/test_pydantic_models.py b/tests/models/test_pydantic_models.py new file mode 100644 index 0000000000000..bc36e0cda4245 --- /dev/null +++ b/tests/models/test_pydantic_models.py @@ -0,0 +1,151 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from pydantic import parse_raw_as + +from airflow.jobs.local_task_job import LocalTaskJob +from airflow.jobs.pydantic.base_job import BaseJobPydantic +from airflow.models.dataset import ( + DagScheduleDatasetReference, + DatasetEvent, + DatasetModel, + TaskOutletDatasetReference, +) +from airflow.models.pydantic.dag_run import DagRunPydantic +from airflow.models.pydantic.dataset import DatasetEventPydantic +from airflow.models.pydantic.taskinstance import TaskInstancePydantic +from airflow.utils import timezone +from airflow.utils.state import State +from airflow.utils.types import DagRunType +from tests.models import DEFAULT_DATE + + +def test_serializing_pydantic_task_instance(session, create_task_instance): + dag_id = "test-dag" + ti = create_task_instance(dag_id=dag_id, session=session) + ti.state = State.RUNNING + ti.next_kwargs = {"foo": "bar"} + session.commit() + + pydantic_task_instance = TaskInstancePydantic.from_orm(ti) + + json_string = pydantic_task_instance.json() + print(json_string) + + deserialized_model = parse_raw_as(TaskInstancePydantic, json_string) + assert deserialized_model.dag_id == dag_id + assert deserialized_model.state == State.RUNNING + assert deserialized_model.next_kwargs == {"foo": "bar"} + + +def test_serializing_pydantic_dagrun(session, create_task_instance): + dag_id = "test-dag" + ti = create_task_instance(dag_id=dag_id, session=session) + ti.dag_run.state = State.RUNNING + session.commit() + + pydantic_dag_run = DagRunPydantic.from_orm(ti.dag_run) + + json_string = pydantic_dag_run.json() + print(json_string) + + deserialized_model = parse_raw_as(DagRunPydantic, json_string) + assert deserialized_model.dag_id == dag_id + assert deserialized_model.state == State.RUNNING + + +def test_serializing_pydantic_local_task_job(session, create_task_instance): + dag_id = "test-dag" + ti = create_task_instance(dag_id=dag_id, session=session) + ltj = LocalTaskJob(task_instance=ti) + ltj.state = State.RUNNING + session.commit() + pydantic_job = BaseJobPydantic.from_orm(ltj) + + json_string = pydantic_job.json() + print(json_string) + + deserialized_model = parse_raw_as(BaseJobPydantic, json_string) + assert deserialized_model.dag_id == dag_id + assert deserialized_model.state == State.RUNNING + assert deserialized_model.task_instance.task_id == ti.task_id + + +def test_serializing_pydantic_dataset_event(session, create_task_instance, create_dummy_dag): + ds1 = DatasetModel(id=1, uri="one", extra={"foo": "bar"}) + ds2 = DatasetModel(id=2, uri="two") + + session.add_all([ds1, ds2]) + session.commit() + + # it's easier to fake a manual run here + dag, task1 = create_dummy_dag( + dag_id="test_triggering_dataset_events", + schedule=None, + start_date=DEFAULT_DATE, + task_id="test_context", + with_dagrun_type=DagRunType.MANUAL, + session=session, + ) + dr = dag.create_dagrun( + run_id="test2", + run_type=DagRunType.DATASET_TRIGGERED, + execution_date=timezone.utcnow(), + state=None, + session=session, + ) + ds1_event = DatasetEvent(dataset_id=1) + ds2_event_1 = DatasetEvent(dataset_id=2) + ds2_event_2 = DatasetEvent(dataset_id=2) + + DagScheduleDatasetReference(dag_id=dag.dag_id, dataset=ds1) + TaskOutletDatasetReference(task_id=task1.task_id, dag_id=dag.dag_id, dataset=ds1) + + dr.consumed_dataset_events.append(ds1_event) + dr.consumed_dataset_events.append(ds2_event_1) + dr.consumed_dataset_events.append(ds2_event_2) + session.commit() + + print(ds2_event_2.dataset.consuming_dags) + pydantic_dse1 = DatasetEventPydantic.from_orm(ds1_event) + json_string1 = pydantic_dse1.json() + print(json_string1) + + pydantic_dse2 = DatasetEventPydantic.from_orm(ds2_event_1) + json_string2 = pydantic_dse2.json() + print(json_string2) + + pydantic_dag_run = DagRunPydantic.from_orm(dr) + json_string_dr = pydantic_dag_run.json() + print(json_string_dr) + + deserialized_model1 = parse_raw_as(DatasetEventPydantic, json_string1) + assert deserialized_model1.dataset.id == 1 + assert deserialized_model1.dataset.uri == "one" + assert len(deserialized_model1.dataset.consuming_dags) == 1 + assert len(deserialized_model1.dataset.producing_tasks) == 1 + + deserialized_model2 = parse_raw_as(DatasetEventPydantic, json_string2) + assert deserialized_model2.dataset.id == 2 + assert deserialized_model2.dataset.uri == "two" + assert len(deserialized_model2.dataset.consuming_dags) == 0 + assert len(deserialized_model2.dataset.producing_tasks) == 0 + + deserialized_dr = parse_raw_as(DagRunPydantic, json_string_dr) + assert len(deserialized_dr.consumed_dataset_events) == 3