diff --git a/airflow/jobs/pydantic/__init__.py b/airflow/jobs/pydantic/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/jobs/pydantic/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/jobs/pydantic/base_job.py b/airflow/jobs/pydantic/base_job.py new file mode 100644 index 0000000000000..bad9aeca48dbc --- /dev/null +++ b/airflow/jobs/pydantic/base_job.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+from typing import Optional
+
+from pydantic import BaseModel as BaseModelPydantic
+
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+
+
+class BaseJobPydantic(BaseModelPydantic):
+    """Serializable representation of the BaseJob ORM SqlAlchemyModel used by internal API."""
+
+    id: Optional[int]
+    dag_id: Optional[str]
+    state: Optional[str]
+    job_type: Optional[str]
+    start_date: Optional[datetime]
+    end_date: Optional[datetime]
+    latest_heartbeat: Optional[datetime]
+    executor_class: Optional[str]
+    hostname: Optional[str]
+    unixname: Optional[str]
+    task_instance: TaskInstancePydantic  # NOTE(review): only required field - confirm every job has one
+
+    class Config:
+        """Enable orm_mode so the model can be populated from SQLAlchemy ORM objects via from_orm."""
+
+        orm_mode = True
diff --git a/airflow/models/pydantic/__init__.py b/airflow/models/pydantic/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/models/pydantic/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/models/pydantic/dag_run.py b/airflow/models/pydantic/dag_run.py new file mode 100644 index 0000000000000..e2d44296b9c60 --- /dev/null +++ b/airflow/models/pydantic/dag_run.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from datetime import datetime
+from typing import List, Optional
+
+from pydantic import BaseModel as BaseModelPydantic
+
+from airflow.models.pydantic.dataset import DatasetEventPydantic
+
+
+class DagRunPydantic(BaseModelPydantic):
+    """Serializable representation of the DagRun ORM SqlAlchemyModel used by internal API."""
+
+    id: int
+    dag_id: str
+    queued_at: Optional[datetime]
+    execution_date: datetime
+    start_date: Optional[datetime]
+    end_date: Optional[datetime]
+    state: str
+    run_id: Optional[str]
+    creating_job_id: Optional[int]
+    external_trigger: bool
+    run_type: str
+    data_interval_start: Optional[datetime]
+    data_interval_end: Optional[datetime]
+    last_scheduling_decision: Optional[datetime]
+    dag_hash: Optional[str]
+    updated_at: datetime
+    consumed_dataset_events: List[DatasetEventPydantic]  # each event is itself a pydantic model
+
+    class Config:
+        """Enable orm_mode so the model can be populated from SQLAlchemy ORM objects via from_orm."""
+
+        orm_mode = True
diff --git a/airflow/models/pydantic/dataset.py b/airflow/models/pydantic/dataset.py
new file mode 100644
index 0000000000000..39c552eea9234
--- /dev/null
+++ b/airflow/models/pydantic/dataset.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import datetime
+from typing import List, Optional
+
+from pydantic import BaseModel as BaseModelPydantic
+
+
+class DagScheduleDatasetReferencePydantic(BaseModelPydantic):
+    """
+    Serializable representation of the DagScheduleDatasetReference
+    ORM SqlAlchemyModel used by internal API.
+    """
+
+    dataset_id: int
+    dag_id: str
+    created_at: datetime
+    updated_at: datetime
+
+    class Config:
+        """Enable orm_mode so the model can be populated from SQLAlchemy ORM objects via from_orm."""
+
+        orm_mode = True
+
+
+class TaskOutletDatasetReferencePydantic(BaseModelPydantic):
+    """
+    Serializable representation of the
+    TaskOutletDatasetReference ORM SqlAlchemyModel used by internal API.
+    """
+
+    dataset_id: int
+    dag_id: str  # must be an annotation (":"); "= str" only sets a class attribute, not a model field
+    task_id: str
+    created_at: datetime
+    updated_at: datetime
+
+    class Config:
+        """Enable orm_mode so the model can be populated from SQLAlchemy ORM objects via from_orm."""
+
+        orm_mode = True
+
+
+class DatasetPydantic(BaseModelPydantic):
+    """Serializable representation of the Dataset ORM SqlAlchemyModel used by internal API."""
+
+    id: int
+    uri: str
+    extra: Optional[dict]
+    created_at: datetime
+    updated_at: datetime
+    is_orphaned: bool
+
+    consuming_dags: List[DagScheduleDatasetReferencePydantic]
+    producing_tasks: List[TaskOutletDatasetReferencePydantic]
+
+    class Config:
+        """Enable orm_mode so the model can be populated from SQLAlchemy ORM objects via from_orm."""
+
+        orm_mode = True
+
+
+class DatasetEventPydantic(BaseModelPydantic):
+    """Serializable representation of the DatasetEvent ORM SqlAlchemyModel used by internal API."""
+
+    id: int
+    source_task_id: Optional[str]
+    source_dag_id: Optional[str]
+    source_run_id: Optional[str]
+    extra: Optional[dict]
+    source_map_index: int
+    timestamp: datetime
+    dataset: DatasetPydantic  # nested model; carries its consuming_dags / producing_tasks references
+
+    class Config:
+        """Enable orm_mode so the model can be populated from SQLAlchemy ORM objects via from_orm."""
+
+        orm_mode = True
diff --git a/airflow/models/pydantic/taskinstance.py b/airflow/models/pydantic/taskinstance.py
new file mode 100644
index 0000000000000..f174996d4d9c1
--- /dev/null
+++ b/airflow/models/pydantic/taskinstance.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+from typing import Optional
+
+from pydantic import BaseModel as BaseModelPydantic
+
+
+class TaskInstancePydantic(BaseModelPydantic):
+    """Serializable representation of the TaskInstance ORM SqlAlchemyModel used by internal API."""
+
+    task_id: str
+    dag_id: str
+    run_id: str
+    map_index: int  # integer index, same type as DatasetEventPydantic.source_map_index
+    start_date: Optional[datetime]
+    end_date: Optional[datetime]
+    duration: Optional[float]
+    state: Optional[str]
+    _try_number: int  # NOTE(review): pydantic v1 skips "_"-prefixed names, so this is not a model field
+    max_tries: int
+    hostname: str
+    unixname: str
+    job_id: Optional[int]
+    pool: str
+    pool_slots: int
+    queue: str
+    priority_weight: Optional[int]
+    operator: str
+    queued_dttm: Optional[datetime]  # a timestamp; pydantic v1 does not coerce datetime into str
+    queued_by_job_id: Optional[int]
+    pid: Optional[int]
+    updated_at: Optional[datetime]
+    external_executor_id: Optional[str]
+    trigger_id: Optional[int]
+    trigger_timeout: Optional[datetime]
+    next_method: Optional[str]
+    next_kwargs: Optional[dict]
+    run_as_user: Optional[str]
+
+    class Config:
+        """Enable orm_mode so the model can be populated from SQLAlchemy ORM objects via from_orm."""
+
+        orm_mode = True
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index
838162649ae12..374d36988b299 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -36,7 +36,10 @@ from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.dataset import DatasetEvent
 from airflow.models.param import ParamsDict
+from airflow.models.pydantic.dag_run import DagRunPydantic
+from airflow.models.pydantic.dataset import DatasetEventPydantic
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
 from airflow.models.taskinstance import TaskInstance
 from airflow.typing_compat import TypedDict
 
 KNOWN_CONTEXT_KEYS: set[str]
@@ -57,7 +60,7 @@ class Context(TypedDict, total=False):
     conf: AirflowConfigParser
     conn: Any
     dag: DAG
-    dag_run: DagRun
+    dag_run: DagRun | DagRunPydantic
     data_interval_end: DateTime
     data_interval_start: DateTime
     ds: str
@@ -82,14 +85,14 @@ class Context(TypedDict, total=False):
     prev_start_date_success: DateTime | None
     run_id: str
     task: BaseOperator
-    task_instance: TaskInstance
+    task_instance: TaskInstance | TaskInstancePydantic
     task_instance_key_str: str
     test_mode: bool
     templates_dict: Mapping[str, Any] | None
-    ti: TaskInstance
+    ti: TaskInstance | TaskInstancePydantic
     tomorrow_ds: str
     tomorrow_ds_nodash: str
-    triggering_dataset_events: Mapping[str, Collection[DatasetEvent]]
+    triggering_dataset_events: Mapping[str, Collection[DatasetEvent | DatasetEventPydantic]]
     ts: str
     ts_nodash: str
     ts_nodash_with_tz: str
diff --git a/pyproject.toml b/pyproject.toml
index 722d50de5394f..7baf4d4c6f003 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -105,6 +105,13 @@ known-third-party = [
 # needed for the test to work
 "tests/decorators/test_python.py" = ["I002"]
 
+# The Pydantic representations of SqlAlchemy Models are not parsed well with Pydantic
+# when __future__.annotations is used so we need to skip them from upgrading
+"airflow/models/pydantic/taskinstance.py" = ["I002"]
+"airflow/models/pydantic/dag_run.py" = ["I002"]
+"airflow/models/pydantic/dataset.py" = ["I002"]
+"airflow/jobs/pydantic/base_job.py" = ["I002"] + # Ignore pydoc style from these "*.pyi" = ["D"] "tests/*" = ["D"] diff --git a/setup.cfg b/setup.cfg index d0f8b2489f9b6..95c6b457e6428 100644 --- a/setup.cfg +++ b/setup.cfg @@ -122,6 +122,7 @@ install_requires = pendulum>=2.0 pluggy>=1.0 psutil>=4.2.0 + pydantic>=1.10.0 pygments>=2.0.1 pyjwt>=2.0.0 python-daemon>=3.0.0 diff --git a/tests/models/test_pydantic_models.py b/tests/models/test_pydantic_models.py new file mode 100644 index 0000000000000..bc36e0cda4245 --- /dev/null +++ b/tests/models/test_pydantic_models.py @@ -0,0 +1,151 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+from pydantic import parse_raw_as
+
+from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.jobs.pydantic.base_job import BaseJobPydantic
+from airflow.models.dataset import (
+    DagScheduleDatasetReference,
+    DatasetEvent,
+    DatasetModel,
+    TaskOutletDatasetReference,
+)
+from airflow.models.pydantic.dag_run import DagRunPydantic
+from airflow.models.pydantic.dataset import DatasetEventPydantic
+from airflow.models.pydantic.taskinstance import TaskInstancePydantic
+from airflow.utils import timezone
+from airflow.utils.state import State
+from airflow.utils.types import DagRunType
+from tests.models import DEFAULT_DATE
+
+
+def test_serializing_pydantic_task_instance(session, create_task_instance):
+    dag_id = "test-dag"
+    ti = create_task_instance(dag_id=dag_id, session=session)
+    ti.state = State.RUNNING
+    ti.next_kwargs = {"foo": "bar"}
+    session.commit()
+
+    pydantic_task_instance = TaskInstancePydantic.from_orm(ti)  # ORM object -> pydantic model
+
+    json_string = pydantic_task_instance.json()
+    print(json_string)
+
+    deserialized_model = parse_raw_as(TaskInstancePydantic, json_string)  # JSON -> pydantic model
+    assert deserialized_model.dag_id == dag_id
+    assert deserialized_model.state == State.RUNNING
+    assert deserialized_model.next_kwargs == {"foo": "bar"}
+
+
+def test_serializing_pydantic_dagrun(session, create_task_instance):
+    dag_id = "test-dag"
+    ti = create_task_instance(dag_id=dag_id, session=session)
+    ti.dag_run.state = State.RUNNING
+    session.commit()
+
+    pydantic_dag_run = DagRunPydantic.from_orm(ti.dag_run)  # ORM object -> pydantic model
+
+    json_string = pydantic_dag_run.json()
+    print(json_string)
+
+    deserialized_model = parse_raw_as(DagRunPydantic, json_string)  # JSON -> pydantic model
+    assert deserialized_model.dag_id == dag_id
+    assert deserialized_model.state == State.RUNNING
+
+
+def test_serializing_pydantic_local_task_job(session, create_task_instance):
+    dag_id = "test-dag"
+    ti = create_task_instance(dag_id=dag_id, session=session)
+    ltj = LocalTaskJob(task_instance=ti)
+    ltj.state = State.RUNNING
+    session.commit()
+    pydantic_job = BaseJobPydantic.from_orm(ltj)  # ORM object -> pydantic model
+
+    json_string = pydantic_job.json()
+    print(json_string)
+
+    deserialized_model = parse_raw_as(BaseJobPydantic, json_string)  # JSON -> pydantic model
+    assert deserialized_model.dag_id == dag_id
+    assert deserialized_model.state == State.RUNNING
+    assert deserialized_model.task_instance.task_id == ti.task_id  # nested model survives the round trip
+
+
+def test_serializing_pydantic_dataset_event(session, create_task_instance, create_dummy_dag):
+    ds1 = DatasetModel(id=1, uri="one", extra={"foo": "bar"})
+    ds2 = DatasetModel(id=2, uri="two")
+
+    session.add_all([ds1, ds2])
+    session.commit()
+
+    # it's easier to fake a manual run here
+    dag, task1 = create_dummy_dag(
+        dag_id="test_triggering_dataset_events",
+        schedule=None,
+        start_date=DEFAULT_DATE,
+        task_id="test_context",
+        with_dagrun_type=DagRunType.MANUAL,
+        session=session,
+    )
+    dr = dag.create_dagrun(
+        run_id="test2",
+        run_type=DagRunType.DATASET_TRIGGERED,
+        execution_date=timezone.utcnow(),
+        state=None,
+        session=session,
+    )
+    ds1_event = DatasetEvent(dataset_id=1)
+    ds2_event_1 = DatasetEvent(dataset_id=2)
+    ds2_event_2 = DatasetEvent(dataset_id=2)
+
+    DagScheduleDatasetReference(dag_id=dag.dag_id, dataset=ds1)  # only ds1 gets a consuming DAG
+    TaskOutletDatasetReference(task_id=task1.task_id, dag_id=dag.dag_id, dataset=ds1)  # and a producer
+
+    dr.consumed_dataset_events.append(ds1_event)
+    dr.consumed_dataset_events.append(ds2_event_1)
+    dr.consumed_dataset_events.append(ds2_event_2)
+    session.commit()
+
+    print(ds2_event_2.dataset.consuming_dags)
+    pydantic_dse1 = DatasetEventPydantic.from_orm(ds1_event)
+    json_string1 = pydantic_dse1.json()
+    print(json_string1)
+
+    pydantic_dse2 = DatasetEventPydantic.from_orm(ds2_event_1)
+    json_string2 = pydantic_dse2.json()
+    print(json_string2)
+
+    pydantic_dag_run = DagRunPydantic.from_orm(dr)
+    json_string_dr = pydantic_dag_run.json()
+    print(json_string_dr)
+
+    deserialized_model1 = parse_raw_as(DatasetEventPydantic, json_string1)
+    assert deserialized_model1.dataset.id == 1
+    assert deserialized_model1.dataset.uri == "one"
+    assert len(deserialized_model1.dataset.consuming_dags) == 1
+    assert len(deserialized_model1.dataset.producing_tasks) == 1
+
+    deserialized_model2 = parse_raw_as(DatasetEventPydantic, json_string2)
+    assert deserialized_model2.dataset.id == 2
+    assert deserialized_model2.dataset.uri == "two"
+    assert len(deserialized_model2.dataset.consuming_dags) == 0  # no references were created for ds2
+    assert len(deserialized_model2.dataset.producing_tasks) == 0
+
+    deserialized_dr = parse_raw_as(DagRunPydantic, json_string_dr)
+    assert len(deserialized_dr.consumed_dataset_events) == 3  # ds1_event plus the two ds2 events