From ef1589e3ee356e7452ea7a58c9c67eca73443260 Mon Sep 17 00:00:00 2001 From: Josh Fell Date: Mon, 31 Jan 2022 23:14:42 -0500 Subject: [PATCH] Refactor operator links to not create ad hoc TaskInstances --- airflow/models/xcom.py | 8 ++++---- airflow/providers/amazon/aws/operators/emr.py | 7 ++++--- .../providers/google/cloud/operators/bigquery.py | 6 +++--- .../providers/google/cloud/operators/dataproc.py | 13 +++++++------ .../providers/google/cloud/operators/mlengine.py | 8 ++++---- .../microsoft/azure/operators/data_factory.py | 10 +++++++--- airflow/providers/qubole/operators/qubole.py | 8 ++++---- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 4402fb4aa61e7..4234d3ed2d46e 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -220,7 +220,7 @@ def get_one( @classmethod def get_one( cls, - execution_date: pendulum.DateTime, + execution_date: datetime.datetime, key: Optional[str] = None, task_id: Optional[str] = None, dag_id: Optional[str] = None, @@ -233,7 +233,7 @@ def get_one( @provide_session def get_one( cls, - execution_date: Optional[pendulum.DateTime] = None, + execution_date: Optional[datetime.datetime] = None, key: Optional[str] = None, task_id: Optional[Union[str, Iterable[str]]] = None, dag_id: Optional[Union[str, Iterable[str]]] = None, @@ -314,7 +314,7 @@ def get_many( @classmethod def get_many( cls, - execution_date: pendulum.DateTime, + execution_date: datetime.datetime, key: Optional[str] = None, task_ids: Union[str, Iterable[str], None] = None, dag_ids: Union[str, Iterable[str], None] = None, @@ -328,7 +328,7 @@ def get_many( @provide_session def get_many( cls, - execution_date: Optional[pendulum.DateTime] = None, + execution_date: Optional[datetime.datetime] = None, key: Optional[str] = None, task_ids: Optional[Union[str, Iterable[str]]] = None, dag_ids: Optional[Union[str, Iterable[str]]] = None, diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 29bac5cb65a32..6ce5dc0cd7c67 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -22,7 +22,7 @@ from uuid import uuid4 from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance +from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.amazon.aws.hooks.emr import EmrHook if TYPE_CHECKING: @@ -238,8 +238,9 @@ def get_link(self, operator: BaseOperator, dttm: datetime) -> str: :param dttm: datetime :return: url link """ - ti = TaskInstance(task=operator, execution_date=dttm) - flow_id = ti.xcom_pull(task_ids=operator.task_id) + flow_id = XCom.get_one( + key="return_value", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) return ( f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}' if flow_id diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 89daa17a446d2..abd2b8b446a7a 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -32,7 +32,6 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator, BaseOperatorLink -from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XCom from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob @@ -84,8 +83,9 @@ def name(self) -> str: return f'BigQuery Console #{self.index + 1}' def get_link(self, operator: BaseOperator, dttm: datetime): - ti = TaskInstance(task=operator, execution_date=dttm) - job_ids = ti.xcom_pull(task_ids=operator.task_id, key='job_id') + job_ids = XCom.get_one( + key='job_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) if not job_ids: return None if len(job_ids) < self.index: diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index 6a86d1c4cd789..9d75ba798f10e 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -36,8 +36,7 @@ from google.protobuf.field_mask_pb2 import FieldMask from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink -from airflow.models.taskinstance import TaskInstance +from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.utils import timezone @@ -59,8 +58,9 @@ class DataprocJobLink(BaseOperatorLink): name = "Dataproc Job" def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - job_conf = ti.xcom_pull(task_ids=operator.task_id, key="job_conf") + job_conf = XCom.get_one( + key="job_conf", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) return ( DATAPROC_JOB_LOG_LINK.format( job_id=job_conf["job_id"], @@ -78,8 +78,9 @@ class DataprocClusterLink(BaseOperatorLink): name = "Dataproc Cluster" def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - cluster_conf = ti.xcom_pull(task_ids=operator.task_id, key="cluster_conf") + cluster_conf = XCom.get_one( + key="cluster_conf", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) return ( DATAPROC_CLUSTER_LINK.format( cluster_name=cluster_conf["cluster_name"], diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index ae038dc9255ce..f2784a08b7fb3 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -22,8 +22,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink -from airflow.models.taskinstance import TaskInstance +from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook if TYPE_CHECKING: @@ -980,8 +979,9 @@ class AIPlatformConsoleLink(BaseOperatorLink): name = "AI Platform Console" def get_link(self, operator, dttm): - task_instance = TaskInstance(task=operator, execution_date=dttm) - gcp_metadata_dict = task_instance.xcom_pull(task_ids=operator.task_id, key="gcp_metadata") + gcp_metadata_dict = XCom.get_one( + key="gcp_metadata", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) if not gcp_metadata_dict: return '' job_id = gcp_metadata_dict['job_id'] diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index 0df599b53c306..5142276efe3bf 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence from airflow.hooks.base import BaseHook -from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance +from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.microsoft.azure.hooks.data_factory import ( AzureDataFactoryHook, AzureDataFactoryPipelineRunException, @@ -35,8 +35,12 @@ class AzureDataFactoryPipelineRunLink(BaseOperatorLink): name = "Monitor Pipeline Run" def get_link(self, operator, dttm): - ti = TaskInstance(task=operator, execution_date=dttm) - run_id = ti.xcom_pull(task_ids=operator.task_id, key="run_id") + run_id = XCom.get_one( + key="run_id", + dag_id=operator.dag.dag_id, + task_id=operator.task_id, + execution_date=dttm, + ) conn = BaseHook.get_connection(operator.azure_data_factory_conn_id) subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"] diff --git a/airflow/providers/qubole/operators/qubole.py b/airflow/providers/qubole/operators/qubole.py index 0e7adf575059f..c8b733598d517 100644 --- a/airflow/providers/qubole/operators/qubole.py +++ b/airflow/providers/qubole/operators/qubole.py @@ -21,8 +21,7 @@ from typing import TYPE_CHECKING, Optional, Sequence from airflow.hooks.base import BaseHook -from airflow.models import BaseOperator, BaseOperatorLink -from airflow.models.taskinstance import TaskInstance +from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.providers.qubole.hooks.qubole import ( COMMAND_ARGS, HYPHEN_ARGS, @@ -48,7 +47,6 @@ def get_link(self, operator: BaseOperator, dttm: datetime) -> str: :param dttm: datetime :return: url link """ - ti = TaskInstance(task=operator, execution_date=dttm) conn = BaseHook.get_connection( getattr(operator, "qubole_conn_id", None) or operator.kwargs['qubole_conn_id'] # type: ignore[attr-defined] @@ -57,7 +55,9 @@ def get_link(self, operator: BaseOperator, dttm: datetime) -> str: host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host) else: host = 'https://api.qubole.com/v2/analyze?command_id=' - qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id') + qds_command_id = XCom.get_one( + key='qbol_cmd_id', dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm + ) url = host + str(qds_command_id) if qds_command_id else '' return url