Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,43 @@ def _send_error_email_notification(
log.exception("Failed to send email notification")


@detail_span("task.execute")
def _run_execute_callable(
    context: Context,
    execute: Callable[..., Any] | functools.partial[Any],
    task: BaseOperator,
) -> Any:
    """
    Invoke ``execute`` for ``task`` inside a private contextvars copy, honouring the execution timeout.

    The context is copied inside this function, i.e. once the ``task.execute``
    detail span wrapping it is active, so spans started by the operator while
    executing are parented to that span instead of to the caller's. The
    ``ExecutorSafeguard`` tracker is set on the copy only: the operator's
    ``execute`` sees it, but the caller's context is left untouched.

    :param context: Task context, passed to ``execute`` as the ``context`` keyword.
    :param execute: Callable to run (the operator's execute or a resume partial).
    :param task: Operator whose ``execution_timeout`` and ``on_kill`` are used.
    :return: Whatever ``execute`` returns.
    :raises AirflowTaskTimeout: If the timeout is non-positive or expires while
        ``execute`` is running; ``task.on_kill()`` is called before re-raising.
    """
    isolated = contextvars.copy_context()
    # Populate the tracker inside the copy so ExecutorSafeguard accepts the call.
    isolated.run(ExecutorSafeguard.tracker.set, task)

    # No (truthy) timeout configured: run straight through.
    if not task.execution_timeout:
        return isolated.run(execute, context=context)

    from airflow.sdk.execution_time.timeout import timeout

    # TODO: handle timeout in case of deferral
    limit_seconds = task.execution_timeout.total_seconds()
    try:
        # The deadline may already have passed; fail fast without running execute.
        if limit_seconds <= 0:
            raise AirflowTaskTimeout()
        # Run the callable under the timeout guard.
        with timeout(limit_seconds):
            return isolated.run(execute, context=context)
    except AirflowTaskTimeout:
        # Give the operator a chance to clean up before propagating the timeout.
        task.on_kill()
        raise


@detail_span("_execute_task")
def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
"""Execute Task (optionally with a Timeout) and push Xcom results."""
Expand All @@ -2087,10 +2124,6 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
assert isinstance(kwargs, dict)
execute = functools.partial(task.resume_execution, next_method=next_method, next_kwargs=kwargs)

ctx = contextvars.copy_context()
# Populate the context var so ExecutorSafeguard doesn't complain
ctx.run(ExecutorSafeguard.tracker.set, task)

# Export context in os.environ to make it available for operators to use.
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
os.environ.update(airflow_context_vars)
Expand All @@ -2106,23 +2139,7 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):

log.info("::endgroup::")

if task.execution_timeout:
from airflow.sdk.execution_time.timeout import timeout

# TODO: handle timeout in case of deferral
timeout_seconds = task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ctx.run(execute, context=context)
except AirflowTaskTimeout:
task.on_kill()
raise
else:
result = ctx.run(execute, context=context)
result = _run_execute_callable(context, execute, task)

if (post_execute_hook := task._post_execute_hook) is not None:
create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
Expand Down
140 changes: 140 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import contextlib
import contextvars
import functools
import json
import os
Expand Down Expand Up @@ -73,6 +74,7 @@
TaskInstanceState,
TIRunContext,
)
from airflow.sdk.bases.operator import ExecutorSafeguard
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, is_arg_set
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, AssetUriRef, Dataset, Model
Expand Down Expand Up @@ -167,6 +169,7 @@
_make_task_span,
_push_xcom_if_needed,
_register_deserialization_allowed_classes,
_run_execute_callable,
_serialize_outlet_events,
_xcom_push,
detail_span,
Expand Down Expand Up @@ -5597,6 +5600,143 @@ def test_exception_in_context_manager_propagates(self):
raise ValueError("boom")


class TestRunExecuteCallable:
    """
    Exercise ``_run_execute_callable`` on its own.

    Covered behaviour: the execute callable runs inside an isolated contextvars
    copy that carries the ExecutorSafeguard tracker, the execution timeout is
    applied when one is configured, and the call is wrapped in a ``task.execute``
    detail span.
    """

    @pytest.fixture(autouse=True)
    def _sampled_carrier_provider(self):
        """Patch the provider lookup so new_dagrun_trace_carrier yields a SAMPLED carrier (see TestDetailSpan)."""
        with mock.patch(
            "airflow._shared.observability.traces.trace.get_tracer_provider",
            return_value=TracerProvider(),
        ):
            yield

    @staticmethod
    def _make_task(execution_timeout=None):
        """Return a ``BaseOperator``-spec'd mock with the given ``execution_timeout``."""
        operator = mock.MagicMock(spec=BaseOperator)
        operator.execution_timeout = execution_timeout
        return operator

    @staticmethod
    def _tracing_setup(detail_level):
        """
        Build the tracing pieces shared by the span tests.

        :return: ``(exporter, tracer, parent_ctx)`` - an in-memory exporter, a tracer
            whose spans are exported to it, and the context extracted from a dag-run
            carrier created with ``task_span_detail_level=detail_level``.
        """
        exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(exporter))
        tracer = provider.get_tracer("test")
        carrier = new_dagrun_trace_carrier(task_span_detail_level=detail_level)
        parent_ctx = TraceContextTextMapPropagator().extract(carrier)
        return exporter, tracer, parent_ctx

    def test_runs_in_isolated_context_with_safeguard_tracker_set(self):
        """Execute runs in an internal context copy with the tracker set; nothing leaks to the caller."""
        marker = contextvars.ContextVar("marker")
        marker.set("outer")
        task = self._make_task()
        observed = {}

        def execute(context):
            # Mutate a context variable and record the tracker as seen from inside the copy.
            marker.set("inner")
            observed["tracker"] = ExecutorSafeguard.tracker.get(None)
            return context["value"] * 2

        outcome = _run_execute_callable(context={"value": 21}, execute=execute, task=task)

        assert outcome == 42
        # Inside the copy, the safeguard tracker pointed at the task.
        assert observed["tracker"] is task
        # The write to ``marker`` stayed inside the copy.
        assert marker.get() == "outer"
        # Likewise the tracker was only set on the copy, never on the caller's context.
        assert ExecutorSafeguard.tracker.get(None) is not task
        task.on_kill.assert_not_called()

    def test_applies_execution_timeout(self):
        """An overrunning callable raises AirflowTaskTimeout and triggers on_kill."""
        task = self._make_task(execution_timeout=timedelta(milliseconds=10))

        def execute(context):
            # Far longer than the 10 ms budget above.
            time.sleep(2)

        with pytest.raises(AirflowTaskTimeout):
            _run_execute_callable(context={}, execute=execute, task=task)

        task.on_kill.assert_called_once()

    def test_fast_fails_when_timeout_already_elapsed(self):
        """A non-positive timeout raises before the callable runs, and on_kill is still called."""
        task = self._make_task(execution_timeout=timedelta(seconds=-1))
        never_called = mock.MagicMock()

        with pytest.raises(AirflowTaskTimeout):
            _run_execute_callable(context={}, execute=never_called, task=task)

        never_called.assert_not_called()
        task.on_kill.assert_called_once()

    def test_emits_task_execute_span_at_detail_level_2(self):
        """Detail level 2 records a ``task.execute`` span around the callable."""
        exporter, tracer, parent_ctx = self._tracing_setup(detail_level=2)
        task = self._make_task()

        def execute(context):
            return "ok"

        with mock.patch("airflow.sdk.execution_time.task_runner.tracer", tracer):
            with tracer.start_as_current_span("parent", context=parent_ctx):
                outcome = _run_execute_callable(context={}, execute=execute, task=task)

        assert outcome == "ok"
        recorded_names = {span.name for span in exporter.get_finished_spans()}
        assert "task.execute" in recorded_names

    def test_operator_child_spans_nest_under_task_execute(self):
        """
        Spans started during execute are children of ``task.execute``, not of its caller.

        ``_run_execute_callable`` copies the context once the ``task.execute`` span is
        current, so a span opened by the operator picks that span as its parent
        rather than the surrounding one.
        """
        exporter, tracer, parent_ctx = self._tracing_setup(detail_level=2)
        task = self._make_task()

        def execute(context):
            with tracer.start_as_current_span("operator_child"):
                return "ok"

        with mock.patch("airflow.sdk.execution_time.task_runner.tracer", tracer):
            with tracer.start_as_current_span("parent", context=parent_ctx):
                outcome = _run_execute_callable(context={}, execute=execute, task=task)

        assert outcome == "ok"
        by_name = {span.name: span for span in exporter.get_finished_spans()}
        assert by_name["operator_child"].parent.span_id == by_name["task.execute"].context.span_id

    def test_no_task_execute_span_at_detail_level_1(self):
        """Detail level 1 still runs the callable but records no ``task.execute`` span."""
        exporter, tracer, parent_ctx = self._tracing_setup(detail_level=1)
        task = self._make_task()

        def execute(context):
            return "ok"

        with mock.patch("airflow.sdk.execution_time.task_runner.tracer", tracer):
            with tracer.start_as_current_span("parent", context=parent_ctx):
                outcome = _run_execute_callable(context={}, execute=execute, task=task)

        assert outcome == "ok"
        recorded_names = {span.name for span in exporter.get_finished_spans()}
        assert "task.execute" not in recorded_names


def test_dag_add_result(create_runtime_ti, mock_supervisor_comms):
with DAG(dag_id="test_dag_add_result") as dag:
task = PythonOperator(task_id="t", python_callable=lambda: 123)
Expand Down