From e713231c0bef3d2932128ed9cf0247bbfb17358e Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Wed, 1 Jul 2026 12:06:05 -0700 Subject: [PATCH] Add task.execute detail span around task execute callable (#67877) When task span detail level is greater than 1, the actual execute call was not separately traced, making it hard to see how much of a task's runtime was spent in the operator's own work versus the surrounding setup. Wrapping the execute call in its own span gives that finer-grained breakdown. The contextvars context the callable runs in is snapshotted inside the new helper, after the span is current, so spans the operator emits during execute nest under it rather than alongside it. (cherry picked from commit b006a978d204ad72f88b9c0633418facd2c00330) --- .../airflow/sdk/execution_time/task_runner.py | 59 +++++--- .../execution_time/test_task_runner.py | 140 ++++++++++++++++++ 2 files changed, 178 insertions(+), 21 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index f3fee689928a0..7a77ed9ad23c5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -2064,6 +2064,43 @@ def _send_error_email_notification( log.exception("Failed to send email notification") +@detail_span("task.execute") +def _run_execute_callable( + context: Context, + execute: Callable[..., Any] | functools.partial[Any], + task: BaseOperator, +) -> Any: + """ + Run the task's execute callable, applying the execution timeout if one is set. + + The contextvars snapshot is taken here, after the ``task.execute`` span is + current, so spans the operator emits during ``execute`` nest under it rather + than under the caller. ``ExecutorSafeguard``'s tracker is set into that copy + so the operator's ``execute`` passes the safeguard check, while the copy keeps + the change from leaking into the surrounding context. + """ + ctx = contextvars.copy_context() + ctx.run(ExecutorSafeguard.tracker.set, task) + if task.execution_timeout: + from airflow.sdk.execution_time.timeout import timeout + + # TODO: handle timeout in case of deferral + timeout_seconds = task.execution_timeout.total_seconds() + try: + # It's possible we're already timed out, so fast-fail if true + if timeout_seconds <= 0: + raise AirflowTaskTimeout() + # Run task in timeout wrapper + with timeout(timeout_seconds): + result = ctx.run(execute, context=context) + except AirflowTaskTimeout: + task.on_kill() + raise + else: + result = ctx.run(execute, context=context) + return result + + @detail_span("_execute_task") def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): """Execute Task (optionally with a Timeout) and push Xcom results.""" @@ -2087,10 +2124,6 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): assert isinstance(kwargs, dict) execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs) - ctx = contextvars.copy_context() - # Populate the context var so ExecutorSafeguard doesn't complain - ctx.run(ExecutorSafeguard.tracker.set, task) - # Export context in os.environ to make it available for operators to use. airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) os.environ.update(airflow_context_vars) @@ -2106,23 +2139,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): log.info("::endgroup::") - if task.execution_timeout: - from airflow.sdk.execution_time.timeout import timeout - - # TODO: handle timeout in case of deferral - timeout_seconds = task.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = ctx.run(execute, context=context) - except AirflowTaskTimeout: - task.on_kill() - raise - else: - result = ctx.run(execute, context=context) + result = _run_execute_callable(context, execute, task) if (post_execute_hook := task._post_execute_hook) is not None: create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 64493f6305f3c..1aad1a02d5eab 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -18,6 +18,7 @@ from __future__ import annotations import contextlib +import contextvars import functools import json import os @@ -73,6 +74,7 @@ TaskInstanceState, TIRunContext, ) +from airflow.sdk.bases.operator import ExecutorSafeguard from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, is_arg_set from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, AssetUriRef, Dataset, Model @@ -167,6 +169,7 @@ _make_task_span, _push_xcom_if_needed, _register_deserialization_allowed_classes, + _run_execute_callable, _serialize_outlet_events, _xcom_push, detail_span, @@ -5597,6 +5600,143 @@ def test_exception_in_context_manager_propagates(self): raise ValueError("boom") +class TestRunExecuteCallable: + """Tests for ``_run_execute_callable``. + + It runs the task's execute callable inside an isolated contextvars copy (with + the ExecutorSafeguard tracker set), applies the execution timeout when one is + configured, and wraps the call in a ``task.execute`` detail span. + """ + + @pytest.fixture(autouse=True) + def _sampled_carrier_provider(self): + """Make new_dagrun_trace_carrier produce a SAMPLED carrier (see TestDetailSpan).""" + provider = TracerProvider() + with mock.patch( + "airflow._shared.observability.traces.trace.get_tracer_provider", + return_value=provider, + ): + yield + + @staticmethod + def _make_task(execution_timeout=None): + task = mock.MagicMock(spec=BaseOperator) + task.execution_timeout = execution_timeout + return task + + def test_runs_in_isolated_context_with_safeguard_tracker_set(self): + """The callable runs in an internal context copy that has the safeguard tracker set and does not leak.""" + var = contextvars.ContextVar("marker") + var.set("outer") + task = self._make_task() + seen = {} + + def execute(context): + var.set("inner") + seen["tracker"] = ExecutorSafeguard.tracker.get(None) + return context["value"] * 2 + + result = _run_execute_callable(context={"value": 21}, execute=execute, task=task) + + assert result == 42 + # The safeguard tracker is set to the task inside the copy used to run execute. + assert seen["tracker"] is task + # The mutation happened inside the copy, so it does not leak to the caller's context. + assert var.get() == "outer" + # The .set was confined to the copy, so the tracker never leaked to the caller's context. + assert ExecutorSafeguard.tracker.get(None) is not task + task.on_kill.assert_not_called() + + def test_applies_execution_timeout(self): + """When a timeout is set and the callable overruns, AirflowTaskTimeout is raised and on_kill is called.""" + task = self._make_task(execution_timeout=timedelta(milliseconds=10)) + + def execute(context): + time.sleep(2) + + with pytest.raises(AirflowTaskTimeout): + _run_execute_callable(context={}, execute=execute, task=task) + + task.on_kill.assert_called_once() + + def test_fast_fails_when_timeout_already_elapsed(self): + """A non-positive timeout fast-fails before running the callable and still calls on_kill.""" + task = self._make_task(execution_timeout=timedelta(seconds=-1)) + execute = mock.MagicMock() + + with pytest.raises(AirflowTaskTimeout): + _run_execute_callable(context={}, execute=execute, task=task) + + execute.assert_not_called() + task.on_kill.assert_called_once() + + def test_emits_task_execute_span_at_detail_level_2(self): + """At detail level 2, running the callable produces a recorded ``task.execute`` span.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=2) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) + + task = self._make_task() + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + result = _run_execute_callable(context={}, execute=lambda context: "ok", task=task) + + assert result == "ok" + names = [s.name for s in exporter.get_finished_spans()] + assert "task.execute" in names + + def test_operator_child_spans_nest_under_task_execute(self): + """Spans the operator emits during execute nest under ``task.execute``, not its caller. + + The contextvars snapshot is taken inside ``_run_execute_callable`` after the + ``task.execute`` span is current, so a span started during execute parents to + ``task.execute`` rather than to the surrounding span. + """ + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=2) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) + + task = self._make_task() + + def execute(context): + with t.start_as_current_span("operator_child"): + return "ok" + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + result = _run_execute_callable(context={}, execute=execute, task=task) + + assert result == "ok" + spans = {s.name: s for s in exporter.get_finished_spans()} + assert spans["operator_child"].parent.span_id == spans["task.execute"].context.span_id + + def test_no_task_execute_span_at_detail_level_1(self): + """At detail level 1, no ``task.execute`` span is recorded but the callable still runs.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + t = provider.get_tracer("test") + carrier = new_dagrun_trace_carrier(task_span_detail_level=1) + parent_ctx = TraceContextTextMapPropagator().extract(carrier) + + task = self._make_task() + + with mock.patch("airflow.sdk.execution_time.task_runner.tracer", t): + with t.start_as_current_span("parent", context=parent_ctx): + result = _run_execute_callable(context={}, execute=lambda context: "ok", task=task) + + assert result == "ok" + names = [s.name for s in exporter.get_finished_spans()] + assert "task.execute" not in names + + def test_dag_add_result(create_runtime_ti, mock_supervisor_comms): with DAG(dag_id="test_dag_add_result") as dag: task = PythonOperator(task_id="t", python_callable=lambda: 123)