From de4bf337daa407c98c31cbf5babeb2662248d1b9 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 7 Mar 2025 15:54:00 +0800 Subject: [PATCH 1/2] Move asset evaluation logic out of SDK Asset evaluation is only done in the scheduler, and requires the database in various cases. It is better to split it out into a dedicated class in Airflow core. --- airflow/assets/evaluation.py | 78 +++++++ airflow/models/asset.py | 6 +- airflow/models/dag.py | 5 +- .../airflow/sdk/definitions/asset/__init__.py | 48 +---- .../sdk/definitions/asset/decorators.py | 9 +- .../tests/task_sdk/definitions/test_asset.py | 156 +------------- tests/assets/test_evaluation.py | 199 ++++++++++++++++++ tests/models/test_asset.py | 4 +- tests/timetables/test_assets_timetable.py | 5 +- 9 files changed, 305 insertions(+), 205 deletions(-) create mode 100644 airflow/assets/evaluation.py create mode 100644 tests/assets/test_evaluation.py diff --git a/airflow/assets/evaluation.py b/airflow/assets/evaluation.py new file mode 100644 index 0000000000000..b1a877d4d542b --- /dev/null +++ b/airflow/assets/evaluation.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING + +import attrs + +from airflow.models.asset import expand_alias_to_assets, resolve_ref_to_asset +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetBooleanCondition, + AssetRef, + AssetUniqueKey, + BaseAsset, +) +from airflow.sdk.definitions.asset.decorators import MultiAssetDefinition + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + +@attrs.define +class AssetEvaluator: + """Evaluates whether an asset-like object has been satisfied.""" + + _session: Session + + def _resolve_asset_ref(self, o: AssetRef) -> Asset | None: + asset = resolve_ref_to_asset(**attrs.asdict(o), session=self._session) + return asset.to_public() if asset else None + + def _resolve_asset_alias(self, o: AssetAlias) -> list[Asset]: + asset_models = expand_alias_to_assets(o.name, session=self._session) + return [m.to_public() for m in asset_models] + + @functools.singledispatchmethod + def run(self, o: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool: + raise NotImplementedError(f"can not evaluate {o!r}") + + @run.register + def _(self, o: Asset, statuses: dict[AssetUniqueKey, bool]) -> bool: + return statuses.get(AssetUniqueKey.from_asset(o), False) + + @run.register + def _(self, o: AssetRef, statuses: dict[AssetUniqueKey, bool]) -> bool: + if asset := self._resolve_asset_ref(o): + return self.run(asset, statuses) + return False + + @run.register + def _(self, o: AssetAlias, statuses: dict[AssetUniqueKey, bool]) -> bool: + return any(self.run(x, statuses) for x in self._resolve_asset_alias(o)) + + @run.register + def _(self, o: AssetBooleanCondition, statuses: dict[AssetUniqueKey, bool]) -> bool: + return o.agg_func(self.run(x, statuses) for x in o.objects) + + @run.register + def _(self, o: MultiAssetDefinition, statuses: dict[AssetUniqueKey, bool]) -> bool: + return all(self.run(x, statuses) for x in o.iter_outlets()) diff --git a/airflow/models/asset.py 
b/airflow/models/asset.py index 212a0b3a84c6c..5ec0b2e977cba 100644 --- a/airflow/models/asset.py +++ b/airflow/models/asset.py @@ -70,14 +70,14 @@ def fetch_active_assets_by_uri(uris: Iterable[str], session: Session) -> dict[st } -def expand_alias_to_assets(alias_name: str, session: Session) -> Iterable[AssetModel]: +def expand_alias_to_assets(alias_name: str, *, session: Session) -> Iterable[AssetModel]: """Expand asset alias to resolved assets.""" asset_alias_obj = session.scalar( select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1) ) if asset_alias_obj: - return list(asset_alias_obj.assets) - return [] + return iter(asset_alias_obj.assets) + return iter(()) def resolve_ref_to_asset( diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 16cbf28d64706..cbb73e45f8402 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -68,6 +68,7 @@ from sqlalchemy.sql import Select, expression from airflow import settings, utils +from airflow.assets.evaluation import AssetEvaluator from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.exceptions import ( AirflowException, @@ -2323,12 +2324,14 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, dateti """ from airflow.models.serialized_dag import SerializedDagModel + evaluator = AssetEvaluator(session) + def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool | None: # if dag was serialized before 2.9 and we *just* upgraded, # we may be dealing with old version. In that case, # just wait for the dag to be reserialized. 
try: - return cond.evaluate(statuses, session=session) + return evaluator.run(cond, statuses) except AttributeError: log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id) return None diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 11b7a1b2daefe..5a01a745d9cad 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -17,7 +17,6 @@ from __future__ import annotations -import contextlib import logging import operator import os @@ -34,8 +33,6 @@ from collections.abc import Iterable, Iterator from urllib.parse import SplitResult - from sqlalchemy.orm import Session - from airflow.models.asset import AssetModel from airflow.serialization.serialized_objects import SerializedAssetWatcher from airflow.triggers.base import BaseEventTrigger @@ -233,9 +230,6 @@ def as_expression(self) -> Any: """ raise NotImplementedError - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - raise NotImplementedError - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: raise NotImplementedError @@ -442,9 +436,6 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: def iter_asset_refs(self) -> Iterator[AssetRef]: return iter(()) - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return statuses.get(AssetUniqueKey.from_asset(self), False) - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: """ Iterate an asset as dag dependency. 
@@ -489,19 +480,6 @@ def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: def iter_asset_refs(self) -> Iterator[AssetRef]: yield self - def _resolve_asset(self, *, session: Session | None = None) -> Asset | None: - from airflow.models.asset import resolve_ref_to_asset - from airflow.utils.session import create_session - - with contextlib.nullcontext(session) if session else create_session() as session: - asset = resolve_ref_to_asset(**attrs.asdict(self), session=session) - return asset.to_public() if asset else None - - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - if asset := self._resolve_asset(session=session): - return asset.evaluate(statuses=statuses, session=session) - return False - def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: (dependency_id,) = attrs.astuple(self) if asset := self._resolve_asset(): @@ -553,14 +531,6 @@ class AssetAlias(BaseAsset): name: str = attrs.field(validator=_validate_non_empty_identifier) group: str = attrs.field(kw_only=True, default="asset", validator=_validate_identifier) - def _resolve_assets(self, session: Session | None = None) -> list[Asset]: - from airflow.models.asset import expand_alias_to_assets - from airflow.utils.session import create_session - - with contextlib.nullcontext(session) if session else create_session() as session: - asset_models = expand_alias_to_assets(self.name, session) - return [m.to_public() for m in asset_models] - def as_expression(self) -> Any: """ Serialize the asset alias into its scheduling expression. 
@@ -569,9 +539,6 @@ def as_expression(self) -> Any: """ return {"alias": {"name": self.name, "group": self.group}} - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return any(x.evaluate(statuses=statuses, session=session) for x in self._resolve_assets(session)) - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: return iter(()) @@ -613,8 +580,12 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat ) -class _AssetBooleanCondition(BaseAsset): - """Base class for asset boolean logic.""" +class AssetBooleanCondition(BaseAsset): + """ + Base class for asset boolean logic. + + :meta private: + """ agg_func: Callable[[Iterable], bool] @@ -623,9 +594,6 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = objects - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return self.agg_func(x.evaluate(statuses=statuses, session=session) for x in self.objects) - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: for o in self.objects: yield from o.iter_assets() @@ -648,7 +616,7 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe yield from obj.iter_dag_dependencies(source=source, target=target) -class AssetAny(_AssetBooleanCondition): +class AssetAny(AssetBooleanCondition): """Use to combine assets schedule references in an "or" relationship.""" agg_func = any @@ -671,7 +639,7 @@ def as_expression(self) -> dict[str, Any]: return {"any": [o.as_expression() for o in self.objects]} -class AssetAll(_AssetBooleanCondition): +class AssetAll(AssetBooleanCondition): """Use to combine assets schedule references in an "and" relationship.""" agg_func = all diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index c70a224858b1b..77ab57074bbfe 100644 
--- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -28,8 +28,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Collection, Iterator, Mapping - from sqlalchemy.orm import Session - from airflow.io.path import ObjectStoragePath from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, ScheduleArg @@ -122,9 +120,6 @@ def __attrs_post_init__(self) -> None: with self._source.create_dag(dag_id=self._function.__name__): _AssetMainOperator.from_definition(self) - def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: Session | None = None) -> bool: - return all(o.evaluate(statuses=statuses, session=session) for o in self._source.outlets) - def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: for o in self._source.outlets: yield from o.iter_assets() @@ -141,6 +136,10 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe for obj in self._source.outlets: yield from obj.iter_dag_dependencies(source=source, target=target) + def iter_outlets(self) -> Iterator[BaseAsset]: + """For asset evaluation in the scheduler.""" + return iter(self._source.outlets) + @attrs.define(kw_only=True) class _DAGFactory: diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py index 767cd9e1be714..1637cffac61be 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset.py @@ -37,7 +37,7 @@ _sanitize_uri, ) from airflow.sdk.definitions.dag import DAG -from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG +from airflow.serialization.serialized_objects import SerializedDAG ASSET_MODULE_PATH = "airflow.sdk.definitions.asset" @@ -185,18 +185,6 @@ def test_asset_iter_asset_aliases(): ] -@pytest.mark.parametrize( - "statuses, result", - [ - 
({AssetUniqueKey.from_asset(asset1): True}, True), - ({AssetUniqueKey.from_asset(asset1): False}, False), - ({}, False), - ], -) -def test_asset_evaluate(statuses, result): - assert asset1.evaluate(statuses) is result - - def test_asset_any_operations(): result_or = (asset1 | asset2) | asset3 assert isinstance(result_or, AssetAny) @@ -212,116 +200,6 @@ def test_asset_all_operations(): assert isinstance(result_and, AssetAll) -@pytest.mark.parametrize( - "condition, statuses, result", - [ - ( - AssetAny(asset1, asset2), - {AssetUniqueKey.from_asset(asset1): False, AssetUniqueKey.from_asset(asset2): True}, - True, - ), - ( - AssetAll(asset1, asset2), - {AssetUniqueKey.from_asset(asset1): True, AssetUniqueKey.from_asset(asset2): False}, - False, - ), - ], -) -def test_assset_boolean_condition_evaluate_iter(condition, statuses, result): - """ - Tests _AssetBooleanCondition's evaluate and iter_assets methods through AssetAny and AssetAll. - Ensures AssetAny evaluate returns True with any true condition, AssetAll evaluate returns False if - any condition is false, and both classes correctly iterate over assets without duplication. 
- """ - assert condition.evaluate(statuses) is result - assert dict(condition.iter_assets()) == { - AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1, - AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2, - } - - -@pytest.mark.parametrize( - "inputs, scenario, expected", - [ - # Scenarios for AssetAny - ((True, True, True), "any", True), - ((True, True, False), "any", True), - ((True, False, True), "any", True), - ((True, False, False), "any", True), - ((False, False, True), "any", True), - ((False, True, False), "any", True), - ((False, True, True), "any", True), - ((False, False, False), "any", False), - # Scenarios for AssetAll - ((True, True, True), "all", True), - ((True, True, False), "all", False), - ((True, False, True), "all", False), - ((True, False, False), "all", False), - ((False, False, True), "all", False), - ((False, True, False), "all", False), - ((False, True, True), "all", False), - ((False, False, False), "all", False), - ], -) -def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): - class_ = AssetAny if scenario == "any" else AssetAll - assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] - condition = class_(*assets) - - statuses = {AssetUniqueKey.from_asset(asset): status for asset, status in zip(assets, inputs)} - assert ( - condition.evaluate(statuses) == expected - ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" - - # Serialize and deserialize the condition to test persistence - serialized = BaseSerialization.serialize(condition) - deserialized = BaseSerialization.deserialize(serialized) - assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed" - - -@pytest.mark.parametrize( - "status_values, expected_evaluation", - [ - ( - (False, True, True), - False, - ), # AssetAll requires all conditions to be True, but asset1 is False - ((True, True, True), True), # All conditions are True - ( - (True, False, 
True), - True, - ), # asset1 is True, and AssetAny condition (asset2 or asset3 being True) is met - ( - (True, False, False), - False, - ), # asset1 is True, but neither asset2 nor asset3 meet the AssetAny condition - ], -) -def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): - # Define assets - asset1 = Asset(uri="s3://abc/123") - asset2 = Asset(uri="s3://abc/124") - asset3 = Asset(uri="s3://abc/125") - - # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 - nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) - - statuses = { - AssetUniqueKey.from_asset(asset1): status_values[0], - AssetUniqueKey.from_asset(asset2): status_values[1], - AssetUniqueKey.from_asset(asset3): status_values[2], - } - - assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" - - serialized_condition = BaseSerialization.serialize(nested_condition) - deserialized_condition = BaseSerialization.deserialize(serialized_condition) - - assert ( - deserialized_condition.evaluate(statuses) == expected_evaluation - ), "Post-serialization evaluation mismatch" - - @pytest.fixture def create_test_assets(): """Fixture to create test assets and corresponding models.""" @@ -500,38 +378,10 @@ def test_normalize_uri_valid_uri(mock_get_normalized_scheme): class TestAssetAlias: - @pytest.fixture - def asset(self): - """Example asset links to asset alias resolved_asset_alias_2.""" - return Asset(uri="test://asset1/", name="test_name", group="asset") - - @pytest.fixture - def asset_alias_1(self): - """Example asset alias links to no assets.""" - asset_alias_1 = AssetAlias(name="test_name", group="test") - with mock.patch.object(asset_alias_1, "_resolve_assets", return_value=[]): - yield asset_alias_1 - - @pytest.fixture - def resolved_asset_alias_2(self, asset): - """Example asset alias links to asset.""" - asset_alias_2 = AssetAlias(name="test_name_2") - with 
mock.patch.object(asset_alias_2, "_resolve_assets", return_value=[asset]): - yield asset_alias_2 - - @pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1", "resolved_asset_alias_2"]) - def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name): - alias = request.getfixturevalue(alias_fixture_name) + def test_as_expression(self): + alias = AssetAlias(name="test_name", group="test") assert alias.as_expression() == {"alias": {"name": alias.name, "group": alias.group}} - def test_evalute_empty(self, asset_alias_1, asset): - assert asset_alias_1.evaluate({AssetUniqueKey.from_asset(asset): True}) is False - assert asset_alias_1._resolve_assets.mock_calls == [mock.call(None)] - - def test_evalute_resolved(self, resolved_asset_alias_2, asset): - assert resolved_asset_alias_2.evaluate({AssetUniqueKey.from_asset(asset): True}) is True - assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call(None)] - class TestAssetSubclasses: @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, "dataset"))) diff --git a/tests/assets/test_evaluation.py b/tests/assets/test_evaluation.py new file mode 100644 index 0000000000000..1c8e909eee180 --- /dev/null +++ b/tests/assets/test_evaluation.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +from airflow.assets.evaluation import AssetEvaluator +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetUniqueKey +from airflow.serialization.serialized_objects import BaseSerialization + +pytestmark = pytest.mark.db_test + +asset1 = Asset(uri="s3://bucket1/data1", name="asset-1") +asset2 = Asset(uri="s3://bucket2/data2", name="asset-2") + + +@pytest.fixture +def evaluator(session): + return AssetEvaluator(session) + + +@pytest.mark.parametrize( + "statuses, result", + [ + ({AssetUniqueKey.from_asset(asset1): True}, True), + ({AssetUniqueKey.from_asset(asset1): False}, False), + ({}, False), + ], +) +def test_asset_evaluate(evaluator, statuses, result): + assert evaluator.run(asset1, statuses) is result + + +@pytest.mark.parametrize( + "condition, statuses, result", + [ + ( + AssetAny(asset1, asset2), + {AssetUniqueKey.from_asset(asset1): False, AssetUniqueKey.from_asset(asset2): True}, + True, + ), + ( + AssetAll(asset1, asset2), + {AssetUniqueKey.from_asset(asset1): True, AssetUniqueKey.from_asset(asset2): False}, + False, + ), + ], +) +def test_assset_boolean_condition_evaluate_iter(evaluator, condition, statuses, result): + """ + Tests _AssetBooleanCondition's evaluate and iter_assets methods through AssetAny and AssetAll. + + Ensures AssetAny evaluate returns True with any true condition, AssetAll evaluate returns False if + any condition is false, and both classes correctly iterate over assets without duplication. 
+ """ + assert evaluator.run(condition, statuses) is result + assert dict(condition.iter_assets()) == { + AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1, + AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2, + } + + +@pytest.mark.parametrize( + "inputs, scenario, expected", + [ + # Scenarios for AssetAny + ((True, True, True), "any", True), + ((True, True, False), "any", True), + ((True, False, True), "any", True), + ((True, False, False), "any", True), + ((False, False, True), "any", True), + ((False, True, False), "any", True), + ((False, True, True), "any", True), + ((False, False, False), "any", False), + # Scenarios for AssetAll + ((True, True, True), "all", True), + ((True, True, False), "all", False), + ((True, False, True), "all", False), + ((True, False, False), "all", False), + ((False, False, True), "all", False), + ((False, True, False), "all", False), + ((False, True, True), "all", False), + ((False, False, False), "all", False), + ], +) +def test_asset_logical_conditions_evaluation_and_serialization(evaluator, inputs, scenario, expected): + class_ = AssetAny if scenario == "any" else AssetAll + assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in range(123, 126)] + condition = class_(*assets) + + statuses = {AssetUniqueKey.from_asset(asset): status for asset, status in zip(assets, inputs)} + assert ( + evaluator.run(condition, statuses) == expected + ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" + + # Serialize and deserialize the condition to test persistence + serialized = BaseSerialization.serialize(condition) + deserialized = BaseSerialization.deserialize(serialized) + assert evaluator.run(deserialized, statuses) == expected, "Serialization round-trip failed" + + +@pytest.mark.parametrize( + "status_values, expected_evaluation", + [ + pytest.param( + (False, True, True), + False, + id="f & (t | t)", + ), # AssetAll requires all conditions to be True, but asset1 is False + pytest.param( + 
(True, True, True), + True, + id="t & (t | t)", + ), # All conditions are True + pytest.param( + (True, False, True), + True, + id="t & (f | t)", + ), # asset1 is True, and AssetAny condition (asset2 or asset3 being True) is met + pytest.param( + (True, False, False), + False, + id="t & (f | f)", + ), # asset1 is True, but neither asset2 nor asset3 meet the AssetAny condition + ], +) +def test_nested_asset_conditions_with_serialization(evaluator, status_values, expected_evaluation): + # Define assets + asset1 = Asset(uri="s3://abc/123") + asset2 = Asset(uri="s3://abc/124") + asset3 = Asset(uri="s3://abc/125") + + # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 and asset3 + nested_condition = AssetAll(asset1, AssetAny(asset2, asset3)) + + statuses = { + AssetUniqueKey.from_asset(asset1): status_values[0], + AssetUniqueKey.from_asset(asset2): status_values[1], + AssetUniqueKey.from_asset(asset3): status_values[2], + } + + assert evaluator.run(nested_condition, statuses) == expected_evaluation, "Initial evaluation mismatch" + + serialized_condition = BaseSerialization.serialize(nested_condition) + deserialized_condition = BaseSerialization.deserialize(serialized_condition) + + assert ( + evaluator.run(deserialized_condition, statuses) == expected_evaluation + ), "Post-serialization evaluation mismatch" + + +class TestAssetAlias: + @pytest.fixture + def asset(self): + """Example asset links to asset alias resolved_asset_alias_2.""" + return Asset(uri="test://asset1/", name="test_name", group="asset") + + @pytest.fixture + def asset_alias_1(self): + """Example asset alias links to no assets.""" + return AssetAlias(name="test_name", group="test") + + @pytest.fixture + def resolved_asset_alias_2(self): + """Example asset alias links to asset.""" + return AssetAlias(name="test_name_2") + + @pytest.fixture + def evaluator(self, session, asset_alias_1, resolved_asset_alias_2, asset): + class _AssetEvaluator(AssetEvaluator): # Can't use mock because 
AssetEvaluator sets __slots__. + def _resolve_asset_alias(self, o): + if o is asset_alias_1: + return [] + elif o is resolved_asset_alias_2: + return [asset] + return super()._resolve_asset_alias(o) + + return _AssetEvaluator(session) + + def test_evaluate_empty(self, evaluator, asset_alias_1, asset): + assert evaluator.run(asset_alias_1, {AssetUniqueKey.from_asset(asset): True}) is False + + def test_evalute_resolved(self, evaluator, resolved_asset_alias_2, asset): + assert evaluator.run(resolved_asset_alias_2, {AssetUniqueKey.from_asset(asset): True}) is True diff --git a/tests/models/test_asset.py b/tests/models/test_asset.py index 1e21252c5028f..5608d4c8d59e0 100644 --- a/tests/models/test_asset.py +++ b/tests/models/test_asset.py @@ -72,7 +72,7 @@ def resolved_asset_alias_2(self, session, asset_model): return asset_alias_2 def test_expand_alias_to_assets_empty(self, session, asset_alias_1): - assert expand_alias_to_assets(asset_alias_1.name, session) == [] + assert list(expand_alias_to_assets(asset_alias_1.name, session=session)) == [] def test_expand_alias_to_assets_resolved(self, session, resolved_asset_alias_2, asset_model): - assert expand_alias_to_assets(resolved_asset_alias_2.name, session) == [asset_model] + assert list(expand_alias_to_assets(resolved_asset_alias_2.name, session=session)) == [asset_model] diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py index 9892b5805bd8d..d8386bc27f5d9 100644 --- a/tests/timetables/test_assets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -273,8 +273,11 @@ def create_test_assets(self): return [Asset(uri=f"test://asset{i}", name=f"hello{i}") for i in range(1, 3)] def test_asset_dag_run_queue_processing(self, session, dag_maker, create_test_assets): + from airflow.assets.evaluation import AssetEvaluator + assets = create_test_assets asset_models = session.query(AssetModel).all() + evaluator = AssetEvaluator(session) with 
dag_maker(schedule=AssetAny(*assets)) as dag: EmptyOperator(task_id="hello") @@ -298,7 +301,7 @@ def test_asset_dag_run_queue_processing(self, session, dag_maker, create_test_as dag = SerializedDAG.deserialize(serialized_dag.data) for asset_uri, status in dag_statuses[dag.dag_id].items(): cond = dag.timetable.asset_condition - assert cond.evaluate({asset_uri: status}), "DAG trigger evaluation failed" + assert evaluator.run(cond, {asset_uri: status}), "DAG trigger evaluation failed" def test_dag_with_complex_asset_condition(self, session, dag_maker): # Create Asset instances From d9acd62281ecd2e0125fa22fb45caefb991b83fe Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 7 Mar 2025 17:12:25 +0800 Subject: [PATCH 2/2] Remove dependency resolution in DAG parser Dependency resolution requires database calls for non-direct asset references, but we don't want that to happen in the SDK. This removes the eager resolution code, so asset aliases and refs are no longer resolved; each one only keeps a dependency marker for itself. An additional PR will be submitted later to do the resolution. This is a part of the asset UI project that involves changes in the API anyway.
--- .../airflow/sdk/definitions/asset/__init__.py | 50 +++++-------------- 1 file changed, 12 insertions(+), 38 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 5a01a745d9cad..94559988c9d4a 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -482,20 +482,12 @@ def iter_asset_refs(self) -> Iterator[AssetRef]: def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: (dependency_id,) = attrs.astuple(self) - if asset := self._resolve_asset(): - yield DagDependency( - source=f"asset-ref:{dependency_id}" if source else "asset", - target="asset" if source else f"asset-ref:{dependency_id}", - dependency_type="asset", - dependency_id=asset.name, - ) - else: - yield DagDependency( - source=source or "asset-ref", - target=target or "asset-ref", - dependency_type="asset-ref", - dependency_id=dependency_id, - ) + yield DagDependency( + source=source or "asset-ref", + target=target or "asset-ref", + dependency_type="asset-ref", + dependency_id=dependency_id, + ) @attrs.define(hash=True) @@ -554,30 +546,12 @@ def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterat :meta private: """ - if not (resolved_assets := self._resolve_assets()): - yield DagDependency( - source=source or "asset-alias", - target=target or "asset-alias", - dependency_type="asset-alias", - dependency_id=self.name, - ) - return - for asset in resolved_assets: - asset_name = asset.name - # asset - yield DagDependency( - source=f"asset-alias:{self.name}" if source else "asset", - target="asset" if source else f"asset-alias:{self.name}", - dependency_type="asset", - dependency_id=asset_name, - ) - # asset alias - yield DagDependency( - source=source or f"asset:{asset_name}", - target=target or f"asset:{asset_name}", - dependency_type="asset-alias", - dependency_id=self.name, - 
) + yield DagDependency( + source=source or "asset-alias", + target=target or "asset-alias", + dependency_type="asset-alias", + dependency_id=self.name, + ) class AssetBooleanCondition(BaseAsset):