diff --git a/airflow-core/src/airflow/executors/workloads/trigger.py b/airflow-core/src/airflow/executors/workloads/trigger.py index edde48f7f73f9..d3b2d0627a7ec 100644 --- a/airflow-core/src/airflow/executors/workloads/trigger.py +++ b/airflow-core/src/airflow/executors/workloads/trigger.py @@ -46,3 +46,6 @@ class RunTrigger(BaseModel): dag_run_data: dict | None = ( None # Serialized DagRun data in dict format so it can be deserialized in trigger subprocess. ) + + # name: uri of all "watched" Assets + watched_assets: dict[str, str] | None = None # Set for BaseEventTrigger asset watchers only diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 3612feb011fc3..b32ea53a2cd7f 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -58,6 +58,7 @@ from airflow.models.trigger import Trigger from airflow.observability.metrics import stats_utils from airflow.sdk.api.datamodels._generated import HITLDetailResponse +from airflow.sdk.definitions.asset import Asset from airflow.sdk.execution_time.comms import ( CommsDecoder, ConnectionResult, @@ -89,6 +90,7 @@ _new_encoder, _RequestFrame, ) +from airflow.sdk.execution_time.context import AssetStateStoreAccessors from airflow.sdk.execution_time.request_handlers import ( handle_delete_variable, handle_delete_xcom, @@ -778,10 +780,16 @@ def _create_workload( session: Session, ) -> workloads.RunTrigger | None: if trigger.task_instance is None: + watched_assets: dict[str, str] | None = None + + if trigger.assets: + watched_assets = {a.name: a.uri for a in trigger.assets} + return workloads.RunTrigger( id=trigger.id, classpath=trigger.classpath, encrypted_kwargs=trigger.encrypted_kwargs, + watched_assets=watched_assets, ) if not trigger.task_instance.dag_version_id: @@ -1313,6 +1321,11 @@ async def create_triggers(self): trigger_instance.triggerer_job_id = self.job_id trigger_instance.timeout_after = workload.timeout_after + if isinstance(trigger_instance, BaseEventTrigger) and workload.watched_assets: + trigger_instance.asset_state_store = AssetStateStoreAccessors( + inlets=[Asset(name=name, uri=uri) for name, uri in workload.watched_assets.items()] + ) + self.triggers[trigger_id] = { "task": asyncio.create_task( self.run_trigger(trigger_id, trigger_instance, workload.timeout_after, context), diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index f587fd8de3ff4..ee57b536d5ec7 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -46,6 +46,7 @@ from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.context import Context + from airflow.sdk.execution_time.context import AssetStateStoreAccessors from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.triggers.shared_stream import SharedStreamProducer @@ -296,6 +297,12 @@ class BaseEventTrigger(BaseTrigger): supports_triggerer_queue: bool = False + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # Injected by the triggerer before run() is called + self.asset_state_store: AssetStateStoreAccessors | None = None + @staticmethod def hash(classpath: str, kwargs: dict[str, Any]) -> int: """ diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 2a875fbfbf6c1..4ae539600d47b 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -51,6 +51,7 @@ from airflow._shared.timezones import timezone from airflow.executors import workloads from airflow.executors.workloads.task import TaskInstanceDTO +from airflow.executors.workloads.trigger import RunTrigger from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import ( _USER_ACTION_CANCEL_MSG, @@ -74,8 +75,17 @@ from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.triggers.file import FileDeleteTrigger from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger -from airflow.sdk import DAG, BaseHook, BaseOperator -from airflow.sdk.execution_time.comms import ToSupervisor, ToTask, _RequestFrame, _ResponseFrame +from airflow.sdk import DAG, Asset, BaseHook, BaseOperator +from airflow.sdk.execution_time.comms import ( + AssetStateStoreResult, + GetAssetStateStoreByName, + SetAssetStateStoreByName, + ToSupervisor, + ToTask, + _RequestFrame, + _ResponseFrame, +) +from airflow.sdk.execution_time.context import AssetStateStoreAccessors from airflow.serialization.serialized_objects import LazyDeserializedDAG from airflow.triggers.base import BaseEventTrigger, BaseTrigger, TriggerEvent from airflow.triggers.shared_stream import SharedStreamProducer @@ -561,6 +571,269 @@ def test_create_workload_uses_supervisor_id_without_job(jobless_supervisor, mock assert factory.log_path == f"/logs/ti.trigger.{jobless_supervisor.id}.log" +def test_create_workload_sets_watched_assets_for_asset_only_trigger(jobless_supervisor, mocker): + """_create_workload() should populate watched_assets when trigger.task_instance is None and assets exist.""" + asset1 = mocker.Mock(spec=Asset) + asset1.name = "my_asset" + asset1.uri = "s3://bucket/key" + + asset2 = mocker.Mock(spec=Asset) + asset2.name = "other_asset" + asset2.uri = "gs://bucket/path" + + trigger = mocker.Mock(spec=BaseEventTrigger) + trigger.id = 42 + trigger.classpath = "some.path.Trigger" + trigger.encrypted_kwargs = "encrypted" + trigger.task_instance = None # Not tied to a Task (similar to a BaseEventTrigger) + trigger.assets = [asset1, asset2] + + workload = jobless_supervisor._create_workload( + trigger=trigger, + dag_bag=mocker.Mock(), + render_log_fname=mocker.Mock(), + session=mocker.Mock(), + ) + + assert workload is not None + assert workload.watched_assets == {"my_asset": "s3://bucket/key", "other_asset": "gs://bucket/path"} + + +def test_create_workload_watched_assets_none_when_no_assets(jobless_supervisor, mocker): + """_create_workload() should set watched_assets=None when trigger.task_instance is None and assets is empty.""" + trigger = mocker.Mock(spec=BaseEventTrigger) + trigger.id = 43 + trigger.classpath = "some.path.Trigger" + trigger.encrypted_kwargs = "encrypted" + trigger.task_instance = None + trigger.assets = [] # No Assets are attached to the trigger + + workload = jobless_supervisor._create_workload( + trigger=trigger, + dag_bag=mocker.Mock(), + render_log_fname=mocker.Mock(), + session=mocker.Mock(), + ) + + assert workload is not None + assert workload.watched_assets is None + + +def test_run_trigger_workload_includes_watched_assets_field(): + """RunTrigger workload should accept and store watched_assets.""" + workload = RunTrigger( + id=1, + classpath="airflow.triggers.testing.SuccessTrigger", + encrypted_kwargs="fake", + watched_assets={"asset_a": "s3://a", "asset_b": "gs://b"}, + ) + assert workload.watched_assets == {"asset_a": "s3://a", "asset_b": "gs://b"} + + +def test_run_trigger_workload_watched_assets_defaults_to_none(): + """RunTrigger workload watched_assets should default to None.""" + workload = RunTrigger( + id=1, + classpath="airflow.triggers.testing.SuccessTrigger", + encrypted_kwargs="fake", + ) + assert workload.watched_assets is None + + +@pytest.fixture +def make_watcher_trigger(): + """Factory fixture: call with a list to get a BaseEventTrigger subclass that appends each new instance.""" + + def factory(injected_instances): + class WatcherTrigger(BaseEventTrigger): + def __init__(self, **kwargs): + super().__init__(**kwargs) + injected_instances.append(self) + + def serialize(self): + return (f"{type(self).__module__}.{type(self).__qualname__}", {}) + + async def run(self): + yield TriggerEvent("done") + + return WatcherTrigger + + return factory + + +@pytest.mark.asyncio +@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath") +async def test_create_triggers_injects_asset_state_store_for_base_event_trigger( + mock_get_classpath, session, make_watcher_trigger +): + """asset_state_store is populated on BaseEventTrigger instances when watched_assets is set.""" + injected_instances = [] + mock_get_classpath.return_value = make_watcher_trigger(injected_instances) + + runner = TriggerRunner() + runner.to_create.append( + workloads.RunTrigger.model_construct( + id=10, + ti=None, + classpath="fake.WatcherTrigger", + encrypted_kwargs="{}", + watched_assets={"my_asset": "s3://bucket/key"}, + ) + ) + + await runner.create_triggers() + + # This is only testing that an exception was NOT thrown when creating the Trigger + assert 10 in runner.triggers + + assert len(injected_instances) == 1 + assert injected_instances[0].asset_state_store is not None + assert isinstance(injected_instances[0].asset_state_store, AssetStateStoreAccessors) + + runner.triggers[10]["task"].cancel() + await runner.cleanup_finished_triggers() + + +@pytest.mark.asyncio +@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath") +async def test_create_triggers_asset_state_store_none_when_no_watched_assets( + mock_get_classpath, session, make_watcher_trigger +): + """asset_state_store stays None when watched_assets is not set on the workload.""" + injected_instances = [] + mock_get_classpath.return_value = make_watcher_trigger(injected_instances) + + runner = TriggerRunner() + runner.to_create.append( + workloads.RunTrigger.model_construct( + id=11, + ti=None, + classpath="fake.WatcherTrigger", + encrypted_kwargs="{}", + watched_assets=None, + ) + ) + + await runner.create_triggers() + + assert len(injected_instances) == 1 + assert injected_instances[0].asset_state_store is None + + runner.triggers[11]["task"].cancel() + await runner.cleanup_finished_triggers() + + +@pytest.mark.asyncio +@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath") +async def test_create_triggers_skips_asset_state_store_for_non_event_trigger(mock_get_classpath, session): + """asset_state_store injection is skipped for plain BaseTrigger (non-BaseEventTrigger) instances.""" + injected_instances: list[BaseTrigger] = [] + + class PlainTrigger(BaseTrigger): + def __init__(self, **kwargs): + super().__init__(**kwargs) + injected_instances.append(self) + + def serialize(self): + return (f"{type(self).__module__}.{type(self).__qualname__}", {}) + + async def run(self): + yield TriggerEvent("done") + + mock_get_classpath.return_value = PlainTrigger + + runner = TriggerRunner() + runner.to_create.append( + workloads.RunTrigger.model_construct( + id=12, ti=None, classpath="fake.PlainTrigger", encrypted_kwargs="{}" + ) + ) + + await runner.create_triggers() + + assert 12 in runner.triggers + assert len(injected_instances) == 1 + assert not hasattr(injected_instances[0], "asset_state_store") + + runner.triggers[12]["task"].cancel() + await runner.cleanup_finished_triggers() + + +@pytest.mark.asyncio +@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath") +async def test_create_triggers_asset_state_store_contains_correct_assets( + mock_get_classpath, session, make_watcher_trigger +): + """AssetStateStoreAccessors built from watched_assets has entries for all provided name/URI pairs.""" + injected_instances = [] + mock_get_classpath.return_value = make_watcher_trigger(injected_instances) + + runner = TriggerRunner() + runner.to_create.append( + workloads.RunTrigger.model_construct( + id=13, + ti=None, + classpath="fake.WatcherTrigger", + encrypted_kwargs="{}", + watched_assets={"asset_a": "s3://bucket/a", "asset_b": "gs://bucket/b"}, + ) + ) + + await runner.create_triggers() + + assert len(injected_instances) == 1 + state_store = injected_instances[0].asset_state_store + + assert state_store is not None + assert isinstance(state_store, AssetStateStoreAccessors) + assert state_store[Asset(name="asset_a", uri="s3://bucket/a")] is not None + assert state_store[Asset(name="asset_b", uri="gs://bucket/b")] is not None + + runner.triggers[13]["task"].cancel() + await runner.cleanup_finished_triggers() + + +@pytest.mark.asyncio +@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath") +async def test_create_triggers_asset_state_store_accessor_reads_and_writes( + mock_get_classpath, session, mock_supervisor_comms, make_watcher_trigger +): + """asset_state_store accessor sends correct SUPERVISOR_COMMS messages on get() and set().""" + injected_instances = [] + mock_get_classpath.return_value = make_watcher_trigger(injected_instances) + + runner = TriggerRunner() + runner.to_create.append( + workloads.RunTrigger.model_construct( + id=14, + ti=None, + classpath="fake.WatcherTrigger", + encrypted_kwargs="{}", + watched_assets={"asset_a": "s3://bucket/a"}, + ) + ) + + await runner.create_triggers() + + assert len(injected_instances) == 1 + state_store = injected_instances[0].asset_state_store + accessor = state_store[Asset(name="asset_a", uri="s3://bucket/a")] + + mock_supervisor_comms.send.return_value = AssetStateStoreResult(value="2026-01-01") + result = accessor.get("watermark") + assert result == "2026-01-01" + + mock_supervisor_comms.send.assert_called_with(GetAssetStateStoreByName(name="asset_a", key="watermark")) + + accessor.set("watermark", "2026-06-11") + mock_supervisor_comms.send.assert_called_with( + SetAssetStateStoreByName(name="asset_a", key="watermark", value="2026-06-11") + ) + + runner.triggers[14]["task"].cancel() + await runner.cleanup_finished_triggers() + + def test_trigger_lifecycle(spy_agency: SpyAgency, session, testing_dag_bundle): """ Checks that the triggerer will correctly see a new Trigger in the database diff --git a/airflow-core/tests/unit/triggers/test_base_trigger.py b/airflow-core/tests/unit/triggers/test_base_trigger.py index 429e54615293a..06cb06c94ce9f 100644 --- a/airflow-core/tests/unit/triggers/test_base_trigger.py +++ b/airflow-core/tests/unit/triggers/test_base_trigger.py @@ -17,9 +17,12 @@ # under the License. from __future__ import annotations +from unittest.mock import create_autospec + import pytest from airflow.sdk.bases.operator import BaseOperator +from airflow.sdk.execution_time.context import AssetStateStoreAccessors from airflow.triggers.base import BaseEventTrigger, BaseTrigger, StartTriggerArgs, TriggerEvent @@ -253,6 +256,28 @@ async def stream(): assert [p["region"] for p in payloads] == ["us", "us"] +def test_base_event_trigger_asset_state_store_initialized_to_none(): + """asset_state_store is None before it is set.""" + trigger = _PlainEventTrigger() + assert trigger.asset_state_store is None + + +def test_base_event_trigger_asset_state_store_can_be_set(): + """asset_state_store can be set once the Trigger is initialized.""" + trigger = _PlainEventTrigger() + mock_store = create_autospec(AssetStateStoreAccessors, instance=True) + trigger.asset_state_store = mock_store + assert trigger.asset_state_store is mock_store + + +def test_base_event_trigger_asset_state_store_independent_across_instances(): + """a.asset_state_store does not impact b.asset_state_store.""" + a = _PlainEventTrigger(name="a") + b = _PlainEventTrigger(name="b") + a.asset_state_store = create_autospec(AssetStateStoreAccessors, instance=True) + assert b.asset_state_store is None + + def test_create_shared_stream_producer_raises_by_default(): """A subclass that does not override create_shared_stream_producer gets NotImplementedError.