@detail_span("task.execute")
def _run_execute_callable(
    context: Context,
    execute: Callable[..., Any] | functools.partial[Any],
    task: BaseOperator,
) -> Any:
    """
    Invoke the task's execute callable, honouring ``task.execution_timeout`` when it is set.

    The contextvars copy is made inside this function, i.e. once the
    ``task.execute`` span is already current. Spans the operator starts while
    executing are therefore parented to ``task.execute`` rather than to the
    caller's span. The ``ExecutorSafeguard`` tracker is set inside that copy, so
    the operator's ``execute`` passes the safeguard check while the change stays
    out of the surrounding context.

    :param context: Task context, forwarded to ``execute`` as the ``context`` keyword.
    :param execute: Callable to run; may be a ``functools.partial``.
    :param task: Operator whose ``execution_timeout`` is consulted and whose
        ``on_kill`` is invoked when ``AirflowTaskTimeout`` is raised.
    :return: Whatever ``execute`` returns.
    :raises AirflowTaskTimeout: When the configured timeout is non-positive, or
        when the exception is raised while the callable runs under the timeout.
    """
    run_ctx = contextvars.copy_context()
    # Populate the tracker only inside the copy so ExecutorSafeguard is satisfied
    # without touching the caller's context.
    run_ctx.run(ExecutorSafeguard.tracker.set, task)

    # A falsy timeout (``None`` or a zero ``timedelta``) means: run without a limit.
    if not task.execution_timeout:
        return run_ctx.run(execute, context=context)

    from airflow.sdk.execution_time.timeout import timeout

    # TODO: handle timeout in case of deferral
    limit_seconds = task.execution_timeout.total_seconds()
    try:
        # The deadline may already have passed; fail fast instead of starting the task.
        if limit_seconds <= 0:
            raise AirflowTaskTimeout()
        # Run the callable under the timeout wrapper.
        with timeout(limit_seconds):
            return run_ctx.run(execute, context=context)
    except AirflowTaskTimeout:
        # Give the operator a chance to clean up before the timeout propagates.
        task.on_kill()
        raise
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) os.environ.update(airflow_context_vars) @@ -2106,23 +2139,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): log.info("::endgroup::") - if task.execution_timeout: - from airflow.sdk.execution_time.timeout import timeout - - # TODO: handle timeout in case of deferral - timeout_seconds = task.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = ctx.run(execute, context=context) - except AirflowTaskTimeout: - task.on_kill() - raise - else: - result = ctx.run(execute, context=context) + result = _run_execute_callable(context, execute, task) if (post_execute_hook := task._post_execute_hook) is not None: create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 64493f6305f3c..1aad1a02d5eab 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -18,6 +18,7 @@ from __future__ import annotations import contextlib +import contextvars import functools import json import os @@ -73,6 +74,7 @@ TaskInstanceState, TIRunContext, ) +from airflow.sdk.bases.operator import ExecutorSafeguard from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, is_arg_set from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, AssetUriRef, Dataset, Model @@ -167,6 +169,7 @@ _make_task_span, _push_xcom_if_needed, _register_deserialization_allowed_classes, + _run_execute_callable, _serialize_outlet_events, _xcom_push, detail_span, @@ -5597,6 +5600,143 @@ def 
class TestRunExecuteCallable:
    """Exercise ``_run_execute_callable``.

    The function under test runs the execute callable in its own contextvars copy
    (with the ExecutorSafeguard tracker set), enforces the execution timeout when
    the task has one, and wraps the call in a ``task.execute`` detail span.
    """

    # Patch target for the module-level tracer used by the task runner.
    _TRACER_TARGET = "airflow.sdk.execution_time.task_runner.tracer"

    @pytest.fixture(autouse=True)
    def _sampled_carrier_provider(self):
        """Patch the provider lookup so new_dagrun_trace_carrier yields a SAMPLED carrier (see TestDetailSpan)."""
        with mock.patch(
            "airflow._shared.observability.traces.trace.get_tracer_provider",
            return_value=TracerProvider(),
        ):
            yield

    @staticmethod
    def _make_task(execution_timeout=None):
        """Build a ``BaseOperator``-spec'd mock carrying the given ``execution_timeout``."""
        operator = mock.MagicMock(spec=BaseOperator)
        operator.execution_timeout = execution_timeout
        return operator

    @staticmethod
    def _tracing_setup(detail_level):
        """Return ``(exporter, tracer, parent_ctx)`` built from a carrier at *detail_level*."""
        exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(exporter))
        tracer = provider.get_tracer("test")
        carrier = new_dagrun_trace_carrier(task_span_detail_level=detail_level)
        return exporter, tracer, TraceContextTextMapPropagator().extract(carrier)

    def test_runs_in_isolated_context_with_safeguard_tracker_set(self):
        """Execute runs in a private context copy that carries the safeguard tracker; nothing leaks out."""
        marker = contextvars.ContextVar("marker")
        marker.set("outer")
        task = self._make_task()
        observed = {}

        def execute(context):
            marker.set("inner")
            observed["tracker"] = ExecutorSafeguard.tracker.get(None)
            return context["value"] * 2

        outcome = _run_execute_callable(context={"value": 21}, execute=execute, task=task)

        assert outcome == 42
        # Inside the copy, the safeguard tracker points at the task.
        assert observed["tracker"] is task
        # The write to ``marker`` was made in the copy, so the caller still sees its own value.
        assert marker.get() == "outer"
        # Likewise the tracker ``set`` stayed in the copy and is not visible to the caller.
        assert ExecutorSafeguard.tracker.get(None) is not task
        task.on_kill.assert_not_called()

    def test_applies_execution_timeout(self):
        """An overrunning callable raises AirflowTaskTimeout and triggers on_kill."""
        task = self._make_task(execution_timeout=timedelta(milliseconds=10))

        def execute(context):
            time.sleep(2)

        with pytest.raises(AirflowTaskTimeout):
            _run_execute_callable(context={}, execute=execute, task=task)

        task.on_kill.assert_called_once()

    def test_fast_fails_when_timeout_already_elapsed(self):
        """A non-positive timeout raises before the callable is invoked, and on_kill still runs."""
        task = self._make_task(execution_timeout=timedelta(seconds=-1))
        execute = mock.MagicMock()

        with pytest.raises(AirflowTaskTimeout):
            _run_execute_callable(context={}, execute=execute, task=task)

        execute.assert_not_called()
        task.on_kill.assert_called_once()

    def test_emits_task_execute_span_at_detail_level_2(self):
        """With detail level 2, a ``task.execute`` span is recorded for the call."""
        exporter, tracer, parent_ctx = self._tracing_setup(detail_level=2)
        task = self._make_task()

        with mock.patch(self._TRACER_TARGET, tracer):
            with tracer.start_as_current_span("parent", context=parent_ctx):
                outcome = _run_execute_callable(context={}, execute=lambda context: "ok", task=task)

        assert outcome == "ok"
        recorded = {span.name for span in exporter.get_finished_spans()}
        assert "task.execute" in recorded

    def test_operator_child_spans_nest_under_task_execute(self):
        """A span opened during execute is parented to ``task.execute``, not to the caller's span.

        ``_run_execute_callable`` copies the contextvars context only after the
        ``task.execute`` span is current, so spans started inside execute see it
        as their parent.
        """
        exporter, tracer, parent_ctx = self._tracing_setup(detail_level=2)
        task = self._make_task()

        def execute(context):
            with tracer.start_as_current_span("operator_child"):
                return "ok"

        with mock.patch(self._TRACER_TARGET, tracer):
            with tracer.start_as_current_span("parent", context=parent_ctx):
                outcome = _run_execute_callable(context={}, execute=execute, task=task)

        assert outcome == "ok"
        by_name = {span.name: span for span in exporter.get_finished_spans()}
        child = by_name["operator_child"]
        task_execute = by_name["task.execute"]
        assert child.parent.span_id == task_execute.context.span_id

    def test_no_task_execute_span_at_detail_level_1(self):
        """With detail level 1, the callable still runs but no ``task.execute`` span is recorded."""
        exporter, tracer, parent_ctx = self._tracing_setup(detail_level=1)
        task = self._make_task()

        with mock.patch(self._TRACER_TARGET, tracer):
            with tracer.start_as_current_span("parent", context=parent_ctx):
                outcome = _run_execute_callable(context={}, execute=lambda context: "ok", task=task)

        assert outcome == "ok"
        recorded = {span.name for span in exporter.get_finished_spans()}
        assert "task.execute" not in recorded
python_callable=lambda: 123)