diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py index c98641b0c6bd6..48ddb4b0c3197 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py @@ -423,19 +423,22 @@ def get_task_instance(self, *, session: Session) -> TaskInstance: """Get the task instance for the current trigger (Airflow 2.x compatibility).""" from sqlalchemy import select + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") query = select(TaskInstance).where( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, ) task_instance = session.scalars(query).one_or_none() if task_instance is None: raise ValueError( - f"TaskInstance with dag_id: {self.task_instance.dag_id}, " - f"task_id: {self.task_instance.task_id}, " - f"run_id: {self.task_instance.run_id} and " - f"map_index: {self.task_instance.map_index} is not found" + f"TaskInstance with dag_id: {ti.dag_id}, " + f"task_id: {ti.task_id}, " + f"run_id: {ti.run_id} and " + f"map_index: {ti.map_index} is not found" ) return task_instance diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py index 0599b0519a6ea..4ee0b90c51f8c 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -440,21 +440,24 @@ def pod_manager(self) -> AsyncPodManager: @provide_session def get_task_instance(self, *, session: Session) -> TaskInstance: """Get the task instance for this trigger from the database (Airflow 2.x only).""" + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_instance = session.scalar( select(TaskInstance).where( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, ) ) if task_instance is None: raise AirflowException( "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_instance diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py index 659cc9f8bf335..e11059dfbb897 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py @@ -120,42 +120,48 @@ async def on_kill(self) -> None: @provide_session def get_task_instance(self, *, session: Session) -> TaskInstance: + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_instance = session.scalar( select(TaskInstance).where( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, ) ) if task_instance is None: raise AirflowException( "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_instance async def get_task_state(self): from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)( - dag_id=self.task_instance.dag_id, - task_ids=[self.task_instance.task_id], - run_ids=[self.task_instance.run_id], - map_index=self.task_instance.map_index, + dag_id=ti.dag_id, + task_ids=[ti.task_id], + run_ids=[ti.run_id], + map_index=ti.map_index, ) try: - task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id] + task_state = task_states_response[ti.run_id][ti.task_id] except Exception: raise AirflowException( "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_state diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index 5a218c215646f..17eba607d58bf 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -141,42 +141,48 @@ def get_task_instance(self, *, session: Session) -> TaskInstance: :param session: Sqlalchemy session """ + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_instance = session.scalar( select(TaskInstance).where( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, ) ) if task_instance is None: raise AirflowException( "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_instance async def get_task_state(self): from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)( - dag_id=self.task_instance.dag_id, - task_ids=[self.task_instance.task_id], - run_ids=[self.task_instance.run_id], - map_index=self.task_instance.map_index, + dag_id=ti.dag_id, + task_ids=[ti.task_id], + run_ids=[ti.run_id], + map_index=ti.map_index, ) try: - task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id] + task_state = task_states_response[ti.run_id][ti.task_id] except Exception: raise AirflowException( "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_state @@ -293,42 +299,48 @@ def get_task_instance(self, *, session: Session) -> TaskInstance: :param session: Sqlalchemy session """ + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_instance = session.scalar( select(TaskInstance).where( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, ) ) if task_instance is None: raise RuntimeError( "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_instance async def get_task_state(self): from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)( - dag_id=self.task_instance.dag_id, - task_ids=[self.task_instance.task_id], - run_ids=[self.task_instance.run_id], - map_index=self.task_instance.map_index, + dag_id=ti.dag_id, + task_ids=[ti.task_id], + run_ids=[ti.run_id], + map_index=ti.map_index, ) try: - task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id] + task_state = task_states_response[ti.run_id][ti.task_id] except Exception: raise RuntimeError( "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_state @@ -432,42 +444,48 @@ def serialize(self) -> tuple[str, dict[str, Any]]: @provide_session def get_task_instance(self, *, session: Session) -> TaskInstance: + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_instance = session.scalar( select(TaskInstance).where( - TaskInstance.dag_id == self.task_instance.dag_id, - TaskInstance.task_id == self.task_instance.task_id, - TaskInstance.run_id == self.task_instance.run_id, - TaskInstance.map_index == self.task_instance.map_index, + TaskInstance.dag_id == ti.dag_id, + TaskInstance.task_id == ti.task_id, + TaskInstance.run_id == ti.run_id, + TaskInstance.map_index == ti.map_index, ) ) if task_instance is None: raise AirflowException( "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_instance async def get_task_state(self): from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + ti = self.task_instance + if ti is None: + raise RuntimeError("task_instance is not set on the trigger") task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)( - dag_id=self.task_instance.dag_id, - task_ids=[self.task_instance.task_id], - run_ids=[self.task_instance.run_id], - map_index=self.task_instance.map_index, + dag_id=ti.dag_id, + task_ids=[ti.task_id], + run_ids=[ti.run_id], + map_index=ti.map_index, ) try: - task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id] + task_state = task_states_response[ti.run_id][ti.task_id] except Exception: raise AirflowException( "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", - self.task_instance.dag_id, - self.task_instance.task_id, - self.task_instance.run_id, - self.task_instance.map_index, + ti.dag_id, + ti.task_id, + ti.run_id, + ti.map_index, ) return task_state