From 199afd0c34dd443eca0a95a4c572eb52fa93a59a Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 1 Sep 2023 00:08:41 +0200 Subject: [PATCH 1/4] Move the try outside the loop when this is possible in Google provider --- .../executors/kubernetes_executor.py | 8 +-- .../google/cloud/triggers/bigquery.py | 66 +++++++++---------- .../google/cloud/triggers/bigquery_dts.py | 19 +++--- .../google/cloud/triggers/cloud_batch.py | 14 ++-- .../google/cloud/triggers/cloud_build.py | 11 ++-- .../google/cloud/triggers/cloud_sql.py | 20 +++--- .../google/cloud/triggers/dataflow.py | 11 ++-- .../google/cloud/triggers/datafusion.py | 11 ++-- .../google/cloud/triggers/dataproc.py | 37 +++++------ .../cloud/triggers/kubernetes_engine.py | 21 +++--- .../google/cloud/triggers/mlengine.py | 10 +-- 11 files changed, 107 insertions(+), 121 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 90e4927c345c3..6f77cd51e8fb7 100644 --- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -354,8 +354,8 @@ def sync(self) -> None: self.kube_scheduler.sync() last_resource_version: dict[str, str] = defaultdict(lambda: "0") - while True: - try: + try: + while True: results = self.result_queue.get_nowait() try: key, state, pod_name, namespace, resource_version = results @@ -373,8 +373,8 @@ def sync(self) -> None: self.result_queue.put(results) finally: self.result_queue.task_done() - except Empty: - break + except Empty: + pass from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import ResourceVersion diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index edafddf16e02e..1f80479d8b51b 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -75,8 +75,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current job execution status and yields a TriggerEvent.""" """Gets current job execution status and yields a TriggerEvent.""" hook = self._get_async_hook() - while True: - try: + try: + while True: job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) if job_status == "success": yield TriggerEvent( @@ -95,10 +95,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] "Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval ) await asyncio.sleep(self.poll_interval) - except Exception as e: - self.log.exception("Exception occurred while checking for query completion") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) def _get_async_hook(self) -> BigQueryAsyncHook: return BigQueryAsyncHook(gcp_conn_id=self.conn_id) @@ -124,8 +123,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current job execution status and yields a TriggerEvent.""" hook = self._get_async_hook() - while True: - try: + try: + while True: # Poll for job execution status job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) if job_status == "success": @@ -160,10 +159,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] "Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval ) await asyncio.sleep(self.poll_interval) - except Exception as e: - self.log.exception("Exception occurred while checking for query completion") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) class BigQueryGetDataTrigger(BigQueryInsertJobTrigger): @@ -196,8 +194,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current job execution status and yields a TriggerEvent with response data.""" hook = self._get_async_hook() - while True: - try: + try: + while True: # Poll for job execution status job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) if job_status == "success": @@ -220,10 +218,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] "Bigquery job status is %s. Sleeping for %s seconds.", job_status, self.poll_interval ) await asyncio.sleep(self.poll_interval) - except Exception as e: - self.log.exception("Exception occurred while checking for query completion") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): @@ -302,8 +299,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current job execution status and yields a TriggerEvent.""" hook = self._get_async_hook() - while True: - try: + try: + while True: first_job_response_from_hook = await hook.get_job_status( job_id=self.first_job_id, project_id=self.project_id ) @@ -365,10 +362,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] ) return - except Exception as e: - self.log.exception("Exception occurred while checking for query completion") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): @@ -430,8 +426,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current job execution status and yields a TriggerEvent.""" hook = self._get_async_hook() - while True: - try: + try: + while True: # Poll for job execution status response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) if response_from_hook == "success": @@ -448,10 +444,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] else: yield TriggerEvent({"status": "error", "message": response_from_hook, "records": None}) return - except Exception as e: - self.log.exception("Exception occurred while checking for query completion") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) class BigQueryTableExistenceTrigger(BaseTrigger): @@ -501,8 +496,8 @@ def _get_async_hook(self) -> BigQueryTableAsyncHook: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Will run until the table exists in the Google Big Query.""" - while True: - try: + try: + while True: hook = self._get_async_hook() response = await self._table_exists( hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id @@ -511,10 +506,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] yield TriggerEvent({"status": "success", "message": "success"}) return await asyncio.sleep(self.poll_interval) - except Exception as e: - self.log.exception("Exception occurred while checking for Table existence") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for Table existence") + yield TriggerEvent({"status": "error", "message": str(e)}) async def _table_exists( self, hook: BigQueryTableAsyncHook, dataset: str, table_id: str, project_id: str diff --git a/airflow/providers/google/cloud/triggers/bigquery_dts.py b/airflow/providers/google/cloud/triggers/bigquery_dts.py index d5a920a762a2c..16d8ff7b34ff2 100644 --- a/airflow/providers/google/cloud/triggers/bigquery_dts.py +++ b/airflow/providers/google/cloud/triggers/bigquery_dts.py @@ -83,8 +83,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: """If the Transfer Run is in a terminal state, then yield TriggerEvent object.""" hook = self._get_async_hook() - while True: - try: + try: + while True: transfer_run: TransferRun = await hook.get_transfer_run( project_id=self.project_id, config_id=self.config_id, @@ -129,14 +129,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.log.info("Job is still working...") self.log.info("Waiting for %s seconds", self.poll_interval) await asyncio.sleep(self.poll_interval) - except Exception as e: - yield TriggerEvent( - { - "status": "failed", - "message": f"Trigger failed with exception: {e}", - } - ) - return + except Exception as e: + yield TriggerEvent( + { + "status": "failed", + "message": f"Trigger failed with exception: {e}", + } + ) def _get_async_hook(self) -> AsyncBiqQueryDataTransferServiceHook: return AsyncBiqQueryDataTransferServiceHook( diff --git a/airflow/providers/google/cloud/triggers/cloud_batch.py b/airflow/providers/google/cloud/triggers/cloud_batch.py index 211e436c95517..3ae6211fd3eb3 100644 --- a/airflow/providers/google/cloud/triggers/cloud_batch.py +++ b/airflow/providers/google/cloud/triggers/cloud_batch.py @@ -92,9 +92,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: """ timeout = self.timeout hook = self._get_async_hook() - while timeout is None or timeout > 0: - - try: + try: + while timeout is None or timeout > 0: job: Job = await hook.get_batch_job(job_name=self.job_name) status: JobStatus.State = job.status.state @@ -134,10 +133,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: if timeout is None or timeout > 0: await asyncio.sleep(self.polling_period_seconds) - except Exception as e: - self.log.exception("Exception occurred while checking for job completion.") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for job completion.") + yield TriggerEvent({"status": "error", "message": str(e)}) + return self.log.exception(f"Job with name [{self.job_name}] timed out") yield TriggerEvent( @@ -147,7 +146,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: "message": f"Batch job with name {self.job_name} timed out", } ) - return def _get_async_hook(self) -> CloudBatchAsyncHook: return CloudBatchAsyncHook( diff --git a/airflow/providers/google/cloud/triggers/cloud_build.py b/airflow/providers/google/cloud/triggers/cloud_build.py index e07dc939070bf..dddb9d823acf3 100644 --- a/airflow/providers/google/cloud/triggers/cloud_build.py +++ b/airflow/providers/google/cloud/triggers/cloud_build.py @@ -78,8 +78,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current build execution status and yields a TriggerEvent.""" hook = self._get_async_hook() - while True: - try: + try: + while True: # Poll for job execution status cloud_build_instance = await hook.get_cloud_build( id_=self.id_, @@ -119,10 +119,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] ) return - except Exception as e: - self.log.exception("Exception occurred while checking for Cloud Build completion") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for Cloud Build completion") + yield TriggerEvent({"status": "error", "message": str(e)}) def _get_async_hook(self) -> CloudBuildAsyncHook: return CloudBuildAsyncHook(gcp_conn_id=self.gcp_conn_id) diff --git a/airflow/providers/google/cloud/triggers/cloud_sql.py b/airflow/providers/google/cloud/triggers/cloud_sql.py index e04ada9277fcc..be1cd739d242c 100644 --- a/airflow/providers/google/cloud/triggers/cloud_sql.py +++ b/airflow/providers/google/cloud/triggers/cloud_sql.py @@ -64,8 +64,8 @@ def serialize(self): ) async def run(self): - while True: - try: + try: + while True: operation = await self.hook.get_operation( project_id=self.project_id, operation_name=self.operation_name ) @@ -93,11 +93,11 @@ async def run(self): self.poke_interval, ) await asyncio.sleep(self.poke_interval) - except Exception as e: - self.log.exception("Exception occurred while checking operation status.") - yield TriggerEvent( - { - "status": "failed", - "message": str(e), - } - ) + except Exception as e: + self.log.exception("Exception occurred while checking operation status.") + yield TriggerEvent( + { + "status": "failed", + "message": str(e), + } + ) diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py index 5dfdf5106a42b..30f42dfdb19ca 100644 --- a/airflow/providers/google/cloud/triggers/dataflow.py +++ b/airflow/providers/google/cloud/triggers/dataflow.py @@ -92,8 +92,8 @@ async def run(self): amount of time stored in self.poll_sleep variable. """ hook = self._get_async_hook() - while True: - try: + try: + while True: status = await hook.get_job_status( project_id=self.project_id, job_id=self.job_id, @@ -129,10 +129,9 @@ async def run(self): self.log.info("Current job status is: %s", status) self.log.info("Sleeping for %s seconds.", self.poll_sleep) await asyncio.sleep(self.poll_sleep) - except Exception as e: - self.log.exception("Exception occurred while checking for job completion.") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for job completion.") + yield TriggerEvent({"status": "error", "message": str(e)}) def _get_async_hook(self) -> AsyncDataflowHook: return AsyncDataflowHook( diff --git a/airflow/providers/google/cloud/triggers/datafusion.py b/airflow/providers/google/cloud/triggers/datafusion.py index 66ed139f34bf1..06bf5e053eab2 100644 --- a/airflow/providers/google/cloud/triggers/datafusion.py +++ b/airflow/providers/google/cloud/triggers/datafusion.py @@ -83,8 +83,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current pipeline status and yields a TriggerEvent.""" hook = self._get_async_hook() - while True: - try: + try: + while True: # Poll for job execution status response_from_hook = await hook.get_pipeline_status( success_states=self.success_states, @@ -109,10 +109,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] else: yield TriggerEvent({"status": "error", "message": response_from_hook}) return - except Exception as e: - self.log.exception("Exception occurred while checking for pipeline state") - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except Exception as e: + self.log.exception("Exception occurred while checking for pipeline state") + yield TriggerEvent({"status": "error", "message": str(e)}) def _get_async_hook(self) -> DataFusionAsyncHook: return DataFusionAsyncHook( diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 3f94c49965061..17507261a3cc6 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -263,8 +263,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: """Wait until cluster is deleted completely.""" - while self.end_time > time.time(): - try: + try: + while self.end_time > time.time(): cluster = await self.get_async_hook().get_cluster( region=self.region, # type: ignore[arg-type] cluster_name=self.cluster_name, @@ -277,12 +277,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.polling_interval_seconds, ) await asyncio.sleep(self.polling_interval_seconds) - except NotFound: - yield TriggerEvent({"status": "success", "message": ""}) - return - except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) - return + except NotFound: + yield TriggerEvent({"status": "success", "message": ""}) + return + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) + return yield TriggerEvent({"status": "error", "message": "Timeout"}) @@ -312,8 +312,8 @@ def serialize(self): async def run(self) -> AsyncIterator[TriggerEvent]: hook = self.get_async_hook() - while True: - try: + try: + while True: operation = await hook.get_operation(region=self.region, operation_name=self.name) if operation.done: if operation.error.message: @@ -338,12 +338,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: else: self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds) await asyncio.sleep(self.polling_interval_seconds) - except Exception as e: - self.log.exception("Exception occurred while checking operation status.") - yield TriggerEvent( - { - "status": "failed", - "message": str(e), - } - ) - return + except Exception as e: + self.log.exception("Exception occurred while checking operation status.") + yield TriggerEvent( + { + "status": "failed", + "message": str(e), + } + ) diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index d47538a05f0c1..988ccd5558e9e 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -184,8 +184,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets operation status and yields corresponding event.""" hook = self._get_hook() - while True: - try: + try: + while True: operation = await hook.get_operation( operation_name=self.operation_name, project_id=self.project_id, @@ -214,15 +214,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] } ) return - except Exception as e: - self.log.exception("Exception occurred while checking operation status") - yield TriggerEvent( - { - "status": "error", - "message": str(e), - } - ) - return + except Exception as e: + self.log.exception("Exception occurred while checking operation status") + yield TriggerEvent( + { + "status": "error", + "message": str(e), + } + ) def _get_hook(self) -> GKEAsyncHook: if self._hook is None: diff --git a/airflow/providers/google/cloud/triggers/mlengine.py b/airflow/providers/google/cloud/triggers/mlengine.py index 76c542a5bf4d4..d5c6cd60ca3dc 100644 --- a/airflow/providers/google/cloud/triggers/mlengine.py +++ b/airflow/providers/google/cloud/triggers/mlengine.py @@ -91,8 +91,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Gets current job execution status and yields a TriggerEvent.""" hook = self._get_async_hook() - while True: - try: + try: + while True: # Poll for job execution status response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id) if response_from_hook == "success": @@ -110,9 +110,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] else: yield TriggerEvent({"status": "error", "message": response_from_hook}) - except Exception as e: - self.log.exception("Exception occurred while checking for query completion") - yield TriggerEvent({"status": "error", "message": str(e)}) + except Exception as e: + self.log.exception("Exception occurred while checking for query completion") + yield TriggerEvent({"status": "error", "message": str(e)}) def _get_async_hook(self) -> MLEngineAsyncHook: return MLEngineAsyncHook( From 3e990a708a7fb84bc2d9e7d6ddd1ddb76688da0e Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 2 Sep 2023 00:09:07 +0200 Subject: [PATCH 2/4] Update airflow/providers/google/cloud/triggers/dataproc.py Co-authored-by: Tzu-ping Chung --- airflow/providers/google/cloud/triggers/dataproc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 17507261a3cc6..adb34cae03d89 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -279,11 +279,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: await asyncio.sleep(self.polling_interval_seconds) except NotFound: yield TriggerEvent({"status": "success", "message": ""}) - return except Exception as e: yield TriggerEvent({"status": "error", "message": str(e)}) - return - yield TriggerEvent({"status": "error", "message": "Timeout"}) + else: + yield TriggerEvent({"status": "error", "message": "Timeout"}) class DataprocWorkflowTrigger(DataprocBaseTrigger): From a3119b831a18f6450bfdb4ee572827833d93bcd9 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 2 Sep 2023 00:11:31 +0200 Subject: [PATCH 3/4] Use supress instead of except pass --- .../cncf/kubernetes/executors/kubernetes_executor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 6f77cd51e8fb7..4bdcf2bd69d4f 100644 --- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -354,7 +354,7 @@ def sync(self) -> None: self.kube_scheduler.sync() last_resource_version: dict[str, str] = defaultdict(lambda: "0") - try: + with suppress(Empty): while True: results = self.result_queue.get_nowait() try: @@ -373,8 +373,6 @@ def sync(self) -> None: self.result_queue.put(results) finally: self.result_queue.task_done() - except Empty: - pass from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import ResourceVersion From fa7b0c70891a4f6b7ba22b3dd3125c876772f25d Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 2 Sep 2023 00:22:59 +0200 Subject: [PATCH 4/4] revert change in kubernetes provider --- .../cncf/kubernetes/executors/kubernetes_executor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 4bdcf2bd69d4f..90e4927c345c3 100644 --- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -354,8 +354,8 @@ def sync(self) -> None: self.kube_scheduler.sync() last_resource_version: dict[str, str] = defaultdict(lambda: "0") - with suppress(Empty): - while True: + while True: + try: results = self.result_queue.get_nowait() try: key, state, pod_name, namespace, resource_version = results @@ -373,6 +373,8 @@ def sync(self) -> None: self.result_queue.put(results) finally: self.result_queue.task_done() + except Empty: + break from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_utils import ResourceVersion