Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
LABEL_DAG_ID = "dag_id"
LABEL_LOGICAL_DATE = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
LABEL_TRY_NUMBER = "try_number"
LABEL_TASK_INSTANCE_ID = "task_instance_id"
LABEL_RUN_ID = "run_id"


@attrs.define(kw_only=True)
Expand Down Expand Up @@ -177,12 +179,14 @@ def proc(
if task_id := event.get("task_id"):
labels[LABEL_TASK_ID] = str(task_id)
if run_id := event.get("run_id"):
labels["run_id"] = str(run_id)
labels[LABEL_RUN_ID] = str(run_id)
if try_number := event.get("try_number"):
labels[LABEL_TRY_NUMBER] = str(try_number)
if map_index := event.get("map_index"):
labels["map_index"] = str(map_index)

if ti_id := event.get("ti_id"):
labels[StackdriverTaskHandler.LABEL_TASK_INSTANCE_ID] = str(ti_id)
_transport.send(record, str(msg.get("event", "")), resource=self.resource, labels=labels)
return event

Expand Down Expand Up @@ -232,8 +236,24 @@ def escape_label_value(value: str) -> str:
for key, value in self.resource.labels.items():
log_filters.append(f"resource.labels.{escape_label_key(key)}={escape_label_value(value)}")

for key, value in ti_labels.items():
log_filters.append(f"labels.{escape_label_key(key)}={escape_label_value(value)}")
ti_id_val = ti_labels.get(LABEL_TASK_INSTANCE_ID)
legacy_filters = [
f"labels.{escape_label_key(k)}={escape_label_value(v)}"
for k, v in ti_labels.items()
if k != LABEL_TASK_INSTANCE_ID
]

if ti_id_val:
ti_id_filter = (
f"labels.{escape_label_key(LABEL_TASK_INSTANCE_ID)}={escape_label_value(ti_id_val)}"
)
if legacy_filters:
log_filters.append(f"({ti_id_filter} OR ({' AND '.join(legacy_filters)}))")
else:
log_filters.append(ti_id_filter)
else:
log_filters.extend(legacy_filters)

return "\n".join(log_filters)

def read_logs(
Expand Down Expand Up @@ -280,15 +300,23 @@ def _read_single_logs_page(self, log_filter: str, page_token: str | None = None)

def _task_instance_to_labels(ti) -> dict[str, str]:
"""Convert a task instance to Stackdriver labels."""
return {
LABEL_TASK_ID: ti.task_id,
LABEL_DAG_ID: ti.dag_id,
LABEL_LOGICAL_DATE: str(ti.logical_date.isoformat())
if AIRFLOW_V_3_0_PLUS
else str(ti.execution_date.isoformat()),
labels = {
LABEL_TASK_ID: str(ti.task_id),
LABEL_DAG_ID: str(ti.dag_id),
LABEL_TRY_NUMBER: str(ti.try_number),
}

ti_id = getattr(ti, "id", None)
if ti_id:
labels[LABEL_TASK_INSTANCE_ID] = str(ti_id)

if logical_date := getattr(ti, LABEL_LOGICAL_DATE, None):
labels[LABEL_LOGICAL_DATE] = str(logical_date.isoformat())
elif run_id := getattr(ti, "run_id", None):
labels[LABEL_RUN_ID] = str(run_id)

return labels


class StackdriverTaskHandler(logging.Handler):
"""
Expand Down Expand Up @@ -326,6 +354,7 @@ class StackdriverTaskHandler(logging.Handler):
LABEL_DAG_ID = LABEL_DAG_ID
LABEL_LOGICAL_DATE = LABEL_LOGICAL_DATE
LABEL_TRY_NUMBER = LABEL_TRY_NUMBER
LABEL_TASK_INSTANCE_ID = LABEL_TASK_INSTANCE_ID
LOG_VIEWER_BASE_URL = "https://console.cloud.google.com/logs/viewer"
LOG_NAME = "Google Stackdriver"

Expand Down Expand Up @@ -463,14 +492,15 @@ def read(

@classmethod
def _task_instance_to_labels(cls, ti: TaskInstance) -> dict[str, str]:
return {
cls.LABEL_TASK_ID: ti.task_id,
cls.LABEL_DAG_ID: ti.dag_id,
labels = {
cls.LABEL_TASK_ID: str(ti.task_id),
cls.LABEL_DAG_ID: str(ti.dag_id),
cls.LABEL_LOGICAL_DATE: str(ti.logical_date.isoformat())
if AIRFLOW_V_3_0_PLUS
else str(ti.execution_date.isoformat()),
cls.LABEL_TRY_NUMBER: str(ti.try_number),
}
return labels

@property
def log_name(self):
Expand Down
Loading
Loading