Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,37 @@ def _executor_initializer():
This function must be picklable, so it cannot be defined as an inner method or local function.

Reconfigures the ORM engine to prevent issues that arise when multiple processes interact with
the Airflow database.
the Airflow database, and re-initializes ``Stats`` so that metrics emitted from worker
processes (e.g. ``ol.event.size.*`` from ``_emit_manual_state_change_event``) are routed to
the configured statsd backend instead of being silently dropped by ``NoStatsLogger`` — the
parent's ``Stats.initialize(...)`` call from scheduler startup does not propagate across the
spawn boundary.
"""
# This initializer is used only on the scheduler
# We can configure_orm regardless of the Airflow version, as DB access is always allowed from scheduler.
settings.configure_orm()
try:
from airflow.observability.metrics import stats_utils

Stats.initialize(factory=stats_utils.get_stats_factory(Stats))
except ImportError:
# ``stats_utils`` lives under ``airflow.observability.metrics`` in current Airflow; if the
# import path changes or is unavailable, fall through silently — gauge calls will simply
# land on ``NoStatsLogger`` as before, which is no worse than current behavior.
pass


def _emit_manual_state_change_event(adapter_method, stats_key, **kwargs):
    """
    Call ``adapter_method`` with ``kwargs`` and gauge the size of the event it returns.

    The returned event is serialized with ``Serde.to_json``, encoded as UTF-8, and the
    resulting byte length is reported through ``Stats.gauge`` under ``stats_key``.

    Kept at module level so that it stays picklable when handed to the
    ``ProcessPoolExecutor`` by `_on_task_instance_manual_state_change`, which uses it for
    scheduler-side "task state changed externally" emissions.

    :param adapter_method: callable that emits the OL event and returns it.
    :param stats_key: metric name under which the serialized event size is gauged.
    :param kwargs: keyword arguments forwarded unchanged to ``adapter_method``.
    :return: the event object returned by ``adapter_method``.
    """
    emitted = adapter_method(**kwargs)
    # Size is measured on the UTF-8 bytes of the JSON form, not on the character count.
    payload = Serde.to_json(emitted).encode("utf-8")
    Stats.gauge(stats_key, len(payload))
    return emitted


class OpenLineageListener:
Expand Down Expand Up @@ -653,6 +679,17 @@ def _on_task_instance_manual_state_change(
ti_state: TaskInstanceState,
error: None | str | BaseException = None,
) -> None:
"""
Emit an OL event from the scheduler when a TI transitions externally.

This path is only reached on the scheduler (``process_executor_events ->
handle_failure``, or manual UI/API state changes). Emission is routed through
the same ``ProcessPoolExecutor`` the DAG-run listeners use rather than through
``_fork_execute``: the pool's ``_executor_initializer`` rebuilds the ORM once
per worker, so the child never shares a pooled Postgres SSL connection with
the scheduler, and bursts of external-state-change events no longer produce a
fork-per-event.
"""
self.log.debug("`_on_task_instance_manual_state_change` was called with state: `%s`.", ti_state)
end_date = timezone.utcnow()

Expand All @@ -674,45 +711,64 @@ def _on_task_instance_manual_state_change(
)
return

@print_warning(self.log)
def on_state_change():
date = dagrun.logical_date or dagrun.run_after
parent_run_id = self.adapter.build_dag_run_id(
dag_id=ti.dag_id,
logical_date=date,
clear_number=dagrun.clear_number,
)
try:
if not self.executor:
self.log.debug("Executor has not started before `_on_task_instance_manual_state_change`")
return

if ti_state == TaskInstanceState.FAILED:
adapter_method = self.adapter.fail_task
event_type = RunState.FAIL.value.lower()
elif ti_state in (TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED):
adapter_method = self.adapter.complete_task
event_type = RunState.COMPLETE.value.lower()
else:
raise ValueError(f"Unsupported ti_state: `{ti_state}`.")

# Extract primitives from live ORM objects in the parent (scheduler)
# before crossing the pool boundary. Passing ORM objects through the pool
# pickler loses TaskGroup attributes and crashes event emission -- see
# the equivalent note in `on_dag_run_running` (listener.py ~868).
date = dagrun.logical_date or dagrun.run_after
task_uuid = self.adapter.build_task_instance_run_id(
dag_id=ti.dag_id,
task_id=ti.task_id,
try_number=ti.try_number,
logical_date=date,
map_index=ti.map_index,
)
parent_run_id = self.adapter.build_dag_run_id(
dag_id=ti.dag_id,
logical_date=date,
clear_number=dagrun.clear_number,
)

data_interval_start = dagrun.data_interval_start
# Mirror the pattern used in the other listener call sites: convert
# `datetime` to ISO-8601 string, but preserve any non-`datetime`
# value as-is in case a duck-typed caller already passed a string.
data_interval_start: str | datetime | None = dagrun.data_interval_start
if isinstance(data_interval_start, datetime):
data_interval_start = data_interval_start.isoformat()
data_interval_end = dagrun.data_interval_end
data_interval_end: str | datetime | None = dagrun.data_interval_end
if isinstance(data_interval_end, datetime):
data_interval_end = data_interval_end.isoformat()

dag_tags, owners, doc, doc_type = None, None, None, None
airflow_run_facet = {}
dag_tags: list | None = None
owners: list[str] | None = None
doc: str | None = None
doc_type: str | None = None
airflow_run_facet: dict = {}
if task: # on scheduler, we should have access to task
doc, doc_type = get_task_documentation(task)
dag = getattr(task, "dag")
if dag:
if not doc:
doc, doc_type = get_dag_documentation(dag)

dag_tags = dag.tags
owners = [x.strip() for x in (task if task.owner != "airflow" else dag).owner.split(",")]

airflow_run_facet = get_airflow_run_facet(dagrun, dag, ti, task, task_uuid)

adapter_kwargs = {
adapter_kwargs: dict = {
"run_id": task_uuid,
"job_name": get_job_name(ti),
"end_time": end_date.isoformat(),
Expand All @@ -733,23 +789,21 @@ def on_state_change():
**get_airflow_debug_facet(),
},
}

if ti_state == TaskInstanceState.FAILED:
event_type = RunState.FAIL.value.lower()
redacted_event = self.adapter.fail_task(**adapter_kwargs, error=error)
elif ti_state in (TaskInstanceState.SUCCESS, TaskInstanceState.SKIPPED):
event_type = RunState.COMPLETE.value.lower()
redacted_event = self.adapter.complete_task(**adapter_kwargs)
else:
raise ValueError(f"Unsupported ti_state: `{ti_state}`.")
adapter_kwargs["error"] = error

operator_name = ti.operator.lower()
Stats.gauge(
operator_name = (ti.operator or "unknown").lower()
self.submit_callable(
_emit_manual_state_change_event,
adapter_method,
f"ol.event.size.{event_type}.{operator_name}",
len(Serde.to_json(redacted_event).encode("utf-8")),
**adapter_kwargs,
)
except BaseException as e:
self.log.warning(
"OpenLineage received exception in method `_on_task_instance_manual_state_change`",
exc_info=e,
)

self._execute(on_state_change, "on_state_change", use_fork=True)

def _execute(self, callable, callable_name: str, use_fork: bool = False):
if use_fork:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,23 @@ def regular_call(self, callable, callable_name, use_fork):
callable()


def direct_submit_call(self, callable, *args, **kwargs):
    """Run a submitted callable inline instead of via ``OpenLineageListener.submit_callable``.

    Used as a patch target in tests so that nothing goes through the
    ``ProcessPoolExecutor`` and mocked adapter methods (``unittest.mock.Mock``) never
    have to be pickled.

    ``_emit_manual_state_change_event`` gets special treatment: its ``Stats.gauge``
    step would call ``Serde.to_json`` on a ``MagicMock`` return value, so only the
    adapter method it wraps is invoked, with the keyword arguments as given.
    Any other callable is simply called with the provided arguments.
    """
    from airflow.providers.openlineage.plugins.listener import _emit_manual_state_change_event

    if callable is not _emit_manual_state_change_event:
        return callable(*args, **kwargs)

    # Positional layout is (adapter_method, stats_key, ...); only the adapter method is used.
    wrapped_adapter_method, _unused_stats_key, *_rest = args
    return wrapped_adapter_method(**kwargs)


class MockExecutor:
def __init__(self, *args, **kwargs):
self.submitted = False
Expand Down Expand Up @@ -1457,13 +1474,15 @@ def test_adapter_fail_task_is_called_with_dag_description_when_task_doc_is_empty
assert listener.adapter.fail_task.call_args.kwargs["job_description"] == "Test DAG Description"
assert listener.adapter.fail_task.call_args.kwargs["job_description_type"] == "text/plain"

@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._fork_execute")
@mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
@mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True)
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet")
@mock.patch(
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
"airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable",
new=direct_submit_call,
)
def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_model(
self,
Expand All @@ -1472,6 +1491,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_
mock_debug_facet,
mock_debug_mode,
mock_emit,
mock_fork_execute,
time_machine,
):
"""Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments.
Expand All @@ -1482,6 +1502,7 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_
time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False)

listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False)
listener._executor = MagicMock() # satisfy `if not self.executor` guard
mock_get_airflow_run_facet.return_value = {"airflow": 3}
mock_get_task_parent_run_facet.return_value = {"parent": 4}
mock_debug_facet.return_value = {"debug": "packages"}
Expand Down Expand Up @@ -1513,6 +1534,8 @@ def test_adapter_fail_task_is_called_with_proper_arguments_for_db_task_instance_
error=err,
)
listener.adapter.fail_task.assert_called_once_with(**expected_args)
# Regression guard: manual state-change emission must not go through _fork_execute.
mock_fork_execute.assert_not_called()

expected_args["run_id"] = "9d3b14f7-de91-40b6-aeef-e887e2c7673e"
adapter = OpenLineageAdapter()
Expand Down Expand Up @@ -1644,15 +1667,23 @@ def test_adapter_complete_task_is_called_with_dag_description_when_task_doc_is_e
assert listener.adapter.complete_task.call_args.kwargs["job_description"] == "Test DAG Description"
assert listener.adapter.complete_task.call_args.kwargs["job_description_type"] == "text/plain"

@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._fork_execute")
@mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
@mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True)
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet")
@mock.patch(
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
"airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable",
new=direct_submit_call,
)
def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_instance_model(
self, mock_get_task_parent_run_facet, mock_debug_facet, mock_debug_mode, mock_emit, time_machine
self,
mock_get_task_parent_run_facet,
mock_debug_facet,
mock_debug_mode,
mock_emit,
mock_fork_execute,
time_machine,
):
"""Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments.

Expand All @@ -1662,6 +1693,7 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta
time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False)

listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False)
listener._executor = MagicMock() # satisfy `if not self.executor` guard
delattr(task_instance, "task") # Test api server path, where task is not available
mock_get_task_parent_run_facet.return_value = {"parent": 4}
mock_debug_facet.return_value = {"debug": "packages"}
Expand Down Expand Up @@ -1691,6 +1723,8 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta
},
)
assert calls[0][1] == expected_args
# Regression guard: manual state-change emission must not go through _fork_execute.
mock_fork_execute.assert_not_called()

expected_args["run_id"] = "9d3b14f7-de91-40b6-aeef-e887e2c7673e"
adapter = OpenLineageAdapter()
Expand Down Expand Up @@ -1851,15 +1885,23 @@ def test_listener_on_task_instance_skipped_do_not_call_adapter_when_disabled_ope
listener.extractor_manager.extract_metadata.assert_not_called()
listener.adapter.complete_task.assert_not_called()

@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._fork_execute")
@mock.patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
@mock.patch("airflow.providers.openlineage.conf.debug_mode", return_value=True)
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_debug_facet")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_task_parent_run_facet")
@mock.patch(
"airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute", new=regular_call
"airflow.providers.openlineage.plugins.listener.OpenLineageListener.submit_callable",
new=direct_submit_call,
)
def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_instance_model_on_skip(
self, mock_get_task_parent_run_facet, mock_debug_facet, mock_debug_mode, mock_emit, time_machine
self,
mock_get_task_parent_run_facet,
mock_debug_facet,
mock_debug_mode,
mock_emit,
mock_fork_execute,
time_machine,
):
"""Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments.

Expand All @@ -1869,6 +1911,7 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta
time_machine.move_to(timezone.datetime(2023, 1, 3, 13, 1, 1), tick=False)

listener, task_instance = self._create_listener_and_task_instance(runtime_ti=False)
listener._executor = MagicMock() # satisfy `if not self.executor` guard
delattr(task_instance, "task") # Test api server path, where task is not available
mock_get_task_parent_run_facet.return_value = {"parent": 4}
mock_debug_facet.return_value = {"debug": "packages"}
Expand Down Expand Up @@ -1898,6 +1941,8 @@ def test_adapter_complete_task_is_called_with_proper_arguments_for_db_task_insta
},
)
assert calls[0][1] == expected_args
# Regression guard: manual state-change emission must not go through _fork_execute.
mock_fork_execute.assert_not_called()

expected_args["run_id"] = "9d3b14f7-de91-40b6-aeef-e887e2c7673e"
adapter = OpenLineageAdapter()
Expand Down
Loading