diff --git a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py index dd184c230ebf0..e4205d7a76f4b 100644 --- a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py +++ b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py @@ -78,6 +78,8 @@ LABEL_DAG_ID = "dag_id" LABEL_LOGICAL_DATE = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" LABEL_TRY_NUMBER = "try_number" +LABEL_TASK_INSTANCE_ID = "task_instance_id" +LABEL_RUN_ID = "run_id" @attrs.define(kw_only=True) @@ -177,12 +179,14 @@ def proc( if task_id := event.get("task_id"): labels[LABEL_TASK_ID] = str(task_id) if run_id := event.get("run_id"): - labels["run_id"] = str(run_id) + labels[LABEL_RUN_ID] = str(run_id) if try_number := event.get("try_number"): labels[LABEL_TRY_NUMBER] = str(try_number) if map_index := event.get("map_index"): labels["map_index"] = str(map_index) + if ti_id := event.get("ti_id"): + labels[StackdriverTaskHandler.LABEL_TASK_INSTANCE_ID] = str(ti_id) _transport.send(record, str(msg.get("event", "")), resource=self.resource, labels=labels) return event @@ -232,8 +236,24 @@ def escape_label_value(value: str) -> str: for key, value in self.resource.labels.items(): log_filters.append(f"resource.labels.{escape_label_key(key)}={escape_label_value(value)}") - for key, value in ti_labels.items(): - log_filters.append(f"labels.{escape_label_key(key)}={escape_label_value(value)}") + ti_id_val = ti_labels.get(LABEL_TASK_INSTANCE_ID) + legacy_filters = [ + f"labels.{escape_label_key(k)}={escape_label_value(v)}" + for k, v in ti_labels.items() + if k != LABEL_TASK_INSTANCE_ID + ] + + if ti_id_val: + ti_id_filter = ( + f"labels.{escape_label_key(LABEL_TASK_INSTANCE_ID)}={escape_label_value(ti_id_val)}" + ) + if legacy_filters: + log_filters.append(f"({ti_id_filter} OR ({' AND '.join(legacy_filters)}))") + else: + log_filters.append(ti_id_filter) + else: + log_filters.extend(legacy_filters) + return "\n".join(log_filters) def read_logs( @@ -280,15 +300,23 @@ def _read_single_logs_page(self, log_filter: str, page_token: str | None = None) def _task_instance_to_labels(ti) -> dict[str, str]: """Convert a task instance to Stackdriver labels.""" - return { - LABEL_TASK_ID: ti.task_id, - LABEL_DAG_ID: ti.dag_id, - LABEL_LOGICAL_DATE: str(ti.logical_date.isoformat()) - if AIRFLOW_V_3_0_PLUS - else str(ti.execution_date.isoformat()), + labels = { + LABEL_TASK_ID: str(ti.task_id), + LABEL_DAG_ID: str(ti.dag_id), LABEL_TRY_NUMBER: str(ti.try_number), } + ti_id = getattr(ti, "id", None) + if ti_id: + labels[LABEL_TASK_INSTANCE_ID] = str(ti_id) + + if logical_date := getattr(ti, LABEL_LOGICAL_DATE, None): + labels[LABEL_LOGICAL_DATE] = str(logical_date.isoformat()) + elif run_id := getattr(ti, "run_id", None): + labels[LABEL_RUN_ID] = str(run_id) + + return labels + class StackdriverTaskHandler(logging.Handler): """ @@ -326,6 +354,7 @@ class StackdriverTaskHandler(logging.Handler): LABEL_DAG_ID = LABEL_DAG_ID LABEL_LOGICAL_DATE = LABEL_LOGICAL_DATE LABEL_TRY_NUMBER = LABEL_TRY_NUMBER + LABEL_TASK_INSTANCE_ID = LABEL_TASK_INSTANCE_ID LOG_VIEWER_BASE_URL = "https://console.cloud.google.com/logs/viewer" LOG_NAME = "Google Stackdriver" @@ -463,14 +492,15 @@ def read( @classmethod def _task_instance_to_labels(cls, ti: TaskInstance) -> dict[str, str]: - return { - cls.LABEL_TASK_ID: ti.task_id, - cls.LABEL_DAG_ID: ti.dag_id, + labels = { + cls.LABEL_TASK_ID: str(ti.task_id), + cls.LABEL_DAG_ID: str(ti.dag_id), cls.LABEL_LOGICAL_DATE: str(ti.logical_date.isoformat()) if AIRFLOW_V_3_0_PLUS else str(ti.execution_date.isoformat()), cls.LABEL_TRY_NUMBER: str(ti.try_number), } + return labels @property def log_name(self): diff --git a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py index da0068ae67b97..d3c9063784cc0 100644 --- a/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py +++ b/providers/google/tests/unit/google/cloud/log/test_stackdriver_task_handler.py @@ -74,15 +74,22 @@ def test_read_logs(self, mock_client, mock_get_creds_and_project_id): ) mock_get_creds_and_project_id.return_value = ("creds", "project_id") - ti = mock.MagicMock() - ti.task_id = "test_task" - ti.dag_id = "test_dag" - ti.try_number = 1 if AIRFLOW_V_3_0_PLUS: - ti.logical_date = timezone.datetime(2016, 1, 1) + from airflow.sdk.types import RuntimeTaskInstanceProtocol + + ti = mock.MagicMock(spec=RuntimeTaskInstanceProtocol) + ti.id = "test_ti_id" + ti.run_id = "run1" else: + from airflow.models.taskinstance import TaskInstance + + ti = mock.MagicMock(spec=TaskInstance) ti.execution_date = timezone.datetime(2016, 1, 1) + ti.task_id = "test_task" + ti.dag_id = "test_dag" + ti.try_number = 1 + messages, logs = self.io.read("dag_id=test_dag/run_id=run1/task_id=test_task/attempt=1.log", ti) assert len(messages) == 1 @@ -97,15 +104,22 @@ def test_read_logs_empty(self, mock_client, mock_get_creds_and_project_id): ) mock_get_creds_and_project_id.return_value = ("creds", "project_id") - ti = mock.MagicMock() - ti.task_id = "test_task" - ti.dag_id = "test_dag" - ti.try_number = 1 if AIRFLOW_V_3_0_PLUS: - ti.logical_date = timezone.datetime(2016, 1, 1) + from airflow.sdk.types import RuntimeTaskInstanceProtocol + + ti = mock.MagicMock(spec=RuntimeTaskInstanceProtocol) + ti.id = "test_ti_id" + ti.run_id = "run1" else: + from airflow.models.taskinstance import TaskInstance + + ti = mock.MagicMock(spec=TaskInstance) ti.execution_date = timezone.datetime(2016, 1, 1) + ti.task_id = "test_task" + ti.dag_id = "test_dag" + ti.try_number = 1 + messages, logs = self.io.read("test/path", ti) assert len(messages) == 1 @@ -238,6 +252,23 @@ def test_processors_fallback_to_event_labels(self, mock_transport_prop): def test_prepare_log_filter(self, mock_get_creds_and_project_id): mock_get_creds_and_project_id.return_value = ("creds", "project_id") + ti_labels = { + "task_id": "test_task", + "dag_id": "test_dag", + "try_number": "1", + "task_instance_id": "test_ti_id", + } + log_filter = self.io.prepare_log_filter(ti_labels) + + assert 'resource.type="global"' in log_filter + assert 'logName="projects/project_id/logs/airflow"' in log_filter + expected_or = '(labels.task_instance_id="test_ti_id" OR (labels.task_id="test_task" AND labels.dag_id="test_dag" AND labels.try_number="1"))' + assert expected_or in log_filter + + @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id") + def test_prepare_log_filter_legacy(self, mock_get_creds_and_project_id): + mock_get_creds_and_project_id.return_value = ("creds", "project_id") + ti_labels = { "task_id": "test_task", "dag_id": "test_dag", @@ -249,6 +280,7 @@ def test_prepare_log_filter(self, mock_get_creds_and_project_id): assert 'logName="projects/project_id/logs/airflow"' in log_filter assert 'labels.task_id="test_task"' in log_filter assert 'labels.dag_id="test_dag"' in log_filter + assert " OR " not in log_filter @mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id") def test_prepare_log_filter_with_custom_resource(self, mock_get_creds_and_project_id): @@ -431,6 +463,7 @@ def test_should_set_labels(self, mock_client, mock_get_creds_and_project_id): "dag_id": self.DAG_ID, date_key: "2016-01-01T00:00:00+00:00", "try_number": "1", + **({"task_instance_id": str(self.ti.id)} if hasattr(self.ti, "id") else {}), } resource = Resource(type="global", labels={}) self.transport_mock.return_value.send.assert_called_once_with( @@ -456,6 +489,7 @@ def test_should_append_labels(self, mock_client, mock_get_creds_and_project_id): "dag_id": self.DAG_ID, date_key: "2016-01-01T00:00:00+00:00", "try_number": "1", + **({"task_instance_id": str(self.ti.id)} if hasattr(self.ti, "id") else {}), "product.googleapis.com/task_id": "test-value", } resource = Resource(type="global", labels={}) @@ -476,13 +510,23 @@ def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_proj date_label = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - filter_str = ( - 'resource.type="global"\n' - 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' - 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' - f'labels.{date_label}="2016-01-01T00:00:00+00:00"' - ) + if hasattr(self.ti, "id"): + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + f'(labels.task_instance_id="{str(self.ti.id)}" OR (' + 'labels.task_id="task_for_testing_stackdriver_task_handler" AND ' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler" AND ' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"))' + ) + else: + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"' + ) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], @@ -508,13 +552,23 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_ logs, metadata = stackdriver_task_handler.read(self.ti) date_label = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - filter_str = ( - 'resource.type="global"\n' - 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="K\\"OT"\n' - 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' - f'labels.{date_label}="2016-01-01T00:00:00+00:00"' - ) + if hasattr(self.ti, "id"): + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + f'(labels.task_instance_id="{str(self.ti.id)}" OR (' + 'labels.task_id="K\\"OT" AND ' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler" AND ' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"))' + ) + else: + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + 'labels.task_id="K\\"OT"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"' + ) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], @@ -538,14 +592,25 @@ def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_p logs, metadata = stackdriver_task_handler.read(self.ti, 3) date_label = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - filter_str = ( - 'resource.type="global"\n' - 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' - 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' - f'labels.{date_label}="2016-01-01T00:00:00+00:00"\n' - 'labels.try_number="3"' - ) + if hasattr(self.ti, "id"): + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + f'(labels.task_instance_id="{str(self.ti.id)}" OR (' + 'labels.task_id="task_for_testing_stackdriver_task_handler" AND ' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler" AND ' + 'labels.try_number="3" AND ' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"))' + ) + else: + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' + 'labels.try_number="3"\n' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"' + ) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], @@ -570,14 +635,25 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_ logs, metadata1 = stackdriver_task_handler.read(self.ti, 3) date_label = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - filter_str = ( - 'resource.type="global"\n' - 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' - 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' - f'labels.{date_label}="2016-01-01T00:00:00+00:00"\n' - 'labels.try_number="3"' - ) + if hasattr(self.ti, "id"): + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + f'(labels.task_instance_id="{str(self.ti.id)}" OR (' + 'labels.task_id="task_for_testing_stackdriver_task_handler" AND ' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler" AND ' + 'labels.try_number="3" AND ' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"))' + ) + else: + filter_str = ( + 'resource.type="global"\n' + 'logName="projects/project_id/logs/airflow"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' + 'labels.try_number="3"\n' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"' + ) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], @@ -596,14 +672,7 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_ mock_client.return_value.list_log_entries.assert_called_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], - filter=( - 'resource.type="global"\n' - 'logName="projects/project_id/logs/airflow"\n' - 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' - 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' - f'labels.{date_label}="2016-01-01T00:00:00+00:00"\n' - 'labels.try_number="3"' - ), + filter=filter_str, order_by="timestamp asc", page_size=1000, page_token="TOKEN1", @@ -647,16 +716,29 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred logs, metadata = stackdriver_task_handler.read(self.ti) date_label = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - filter_str = ( - 'resource.type="cloud_composer_environment"\n' - 'logName="projects/project_id/logs/airflow"\n' - 'resource.labels."environment.name"="test-instance"\n' - 'resource.labels.location="europe-west-3"\n' - 'resource.labels.project_id="project_id"\n' - 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' - 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' - f'labels.{date_label}="2016-01-01T00:00:00+00:00"' - ) + if hasattr(self.ti, "id"): + filter_str = ( + 'resource.type="cloud_composer_environment"\n' + 'logName="projects/project_id/logs/airflow"\n' + 'resource.labels."environment.name"="test-instance"\n' + 'resource.labels.location="europe-west-3"\n' + 'resource.labels.project_id="project_id"\n' + f'(labels.task_instance_id="{str(self.ti.id)}" OR (' + 'labels.task_id="task_for_testing_stackdriver_task_handler" AND ' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler" AND ' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"))' + ) + else: + filter_str = ( + 'resource.type="cloud_composer_environment"\n' + 'logName="projects/project_id/logs/airflow"\n' + 'resource.labels."environment.name"="test-instance"\n' + 'resource.labels.location="europe-west-3"\n' + 'resource.labels.project_id="project_id"\n' + 'labels.task_id="task_for_testing_stackdriver_task_handler"\n' + 'labels.dag_id="dag_for_testing_stackdriver_file_task_handler"\n' + f'labels.{date_label}="2016-01-01T00:00:00+00:00"' + ) mock_client.return_value.list_log_entries.assert_called_once_with( request=ListLogEntriesRequest( resource_names=["projects/project_id"], @@ -708,14 +790,24 @@ def test_should_return_valid_external_url(self, mock_client, mock_get_creds_and_ filter_params = parsed_qs["advancedFilter"][0].splitlines() date_label = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - expected_filter = [ - 'resource.type="global"', - 'logName="projects/project_id/logs/airflow"', - f'labels.task_id="{self.ti.task_id}"', - f'labels.dag_id="{self.DAG_ID}"', - f'labels.{date_label}="{self.ti.logical_date.isoformat() if AIRFLOW_V_3_0_PLUS else self.ti.execution_date.isoformat()}"', - f'labels.try_number="{self.ti.try_number}"', - ] + if hasattr(self.ti, "id"): + expected_filter = [ + 'resource.type="global"', + 'logName="projects/project_id/logs/airflow"', + f'(labels.task_instance_id="{str(self.ti.id)}" OR (labels.task_id="{self.ti.task_id}" AND ' + f'labels.dag_id="{self.DAG_ID}" AND ' + f'labels.try_number="{self.ti.try_number}" AND ' + f'labels.{date_label}="{self.ti.logical_date.isoformat() if AIRFLOW_V_3_0_PLUS else self.ti.execution_date.isoformat()}"))', + ] + else: + expected_filter = [ + 'resource.type="global"', + 'logName="projects/project_id/logs/airflow"', + f'labels.task_id="{self.ti.task_id}"', + f'labels.dag_id="{self.DAG_ID}"', + f'labels.{date_label}="{self.ti.logical_date.isoformat() if AIRFLOW_V_3_0_PLUS else self.ti.execution_date.isoformat()}"', + f'labels.try_number="{self.ti.try_number}"', + ] assert set(expected_filter) == set(filter_params) @@ -741,12 +833,21 @@ def test_read_falls_back_when_cloud_logging_unavailable( ) handler = StackdriverTaskHandler() - ti = mock.MagicMock() + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.types import RuntimeTaskInstanceProtocol + + ti = mock.MagicMock(spec=RuntimeTaskInstanceProtocol) + ti.id = "test_ti_id" + ti.run_id = "run1" + else: + from airflow.models.taskinstance import TaskInstance + + ti = mock.MagicMock(spec=TaskInstance) + ti.execution_date = mock.MagicMock(isoformat=lambda: "2020-01-01T00:00:00+00:00") + ti.task_id = "t" ti.dag_id = "d" ti.try_number = 1 - ti.logical_date = mock.MagicMock(isoformat=lambda: "2020-01-01T00:00:00+00:00") - ti.execution_date = ti.logical_date with caplog.at_level(logging.ERROR): logs, metadata = handler.read(ti, try_number=1) @@ -777,12 +878,21 @@ def test_read_does_not_leak_internals_in_user_facing_message( ) handler = StackdriverTaskHandler() - ti = mock.MagicMock() + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.types import RuntimeTaskInstanceProtocol + + ti = mock.MagicMock(spec=RuntimeTaskInstanceProtocol) + ti.id = "test_ti_id" + ti.run_id = "run1" + else: + from airflow.models.taskinstance import TaskInstance + + ti = mock.MagicMock(spec=TaskInstance) + ti.execution_date = mock.MagicMock(isoformat=lambda: "2020-01-01T00:00:00+00:00") + ti.task_id = "t" ti.dag_id = "d" ti.try_number = 1 - ti.logical_date = mock.MagicMock(isoformat=lambda: "2020-01-01T00:00:00+00:00") - ti.execution_date = ti.logical_date logs, _ = handler.read(ti, try_number=1)