diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py index ca67373ec64dd..f2f3ba417b1c0 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py @@ -19,7 +19,7 @@ import itertools from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from google.cloud.run_v2 import ( CreateJobRequest, @@ -67,16 +67,21 @@ class CloudRunHook(GoogleBaseHook): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account. + :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'. + If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available + or fails in your environment (e.g., Docker containers with certain network configurations). """ def __init__( self, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + transport: Literal["rest", "grpc"] | None = None, **kwargs, ) -> None: super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs) self._client: JobsClient | None = None + self.transport = transport def get_conn(self): """ @@ -85,7 +90,12 @@ def get_conn(self): :return: Cloud Run Jobs client object. """ if self._client is None: - self._client = JobsClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) + client_kwargs = { + "credentials": self.get_credentials(), + "client_info": CLIENT_INFO, + "transport": self.transport, + } + self._client = JobsClient(**client_kwargs) return self._client @GoogleBaseHook.fallback_to_default_project_id @@ -176,6 +186,9 @@ class CloudRunAsyncHook(GoogleBaseAsyncHook): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account. + :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'. + If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available + or fails in your environment (e.g., Docker containers with certain network configurations). """ sync_hook_class = CloudRunHook @@ -184,15 +197,24 @@ def __init__( self, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + transport: Literal["rest", "grpc"] | None = None, **kwargs, ): self._client: JobsAsyncClient | None = None - super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs) + self.transport = transport + super().__init__( + gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, transport=transport, **kwargs + ) async def get_conn(self): if self._client is None: sync_hook = await self.get_sync_hook() - self._client = JobsAsyncClient(credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO) + client_kwargs = { + "credentials": sync_hook.get_credentials(), + "client_info": CLIENT_INFO, + "transport": self.transport, + } + self._client = JobsAsyncClient(**client_kwargs) return self._client diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py index 7e7b5faf1b4a7..5c12dd4d7dafc 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import google.cloud.exceptions from google.api_core.exceptions import AlreadyExists @@ -263,6 +263,9 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). :param deferrable: Run the operator in deferrable mode. + :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'. + If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available + or fails in your environment (e.g., Docker containers with certain network configurations). """ operator_extra_links = (CloudRunJobLoggingLink(),) @@ -275,6 +278,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator): "overrides", "polling_period_seconds", "timeout_seconds", + "transport", ) def __init__( @@ -288,6 +292,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + transport: Literal["rest", "grpc"] | None = None, **kwargs, ): super().__init__(**kwargs) @@ -300,11 +305,14 @@ def __init__( self.polling_period_seconds = polling_period_seconds self.timeout_seconds = timeout_seconds self.deferrable = deferrable + self.transport = transport self.operation: operation.Operation | None = None def execute(self, context: Context): hook: CloudRunHook = CloudRunHook( - gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + transport=self.transport, ) self.operation = hook.execute_job( region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides @@ -333,6 +341,7 @@ def execute(self, context: Context): gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, polling_period_seconds=self.polling_period_seconds, + transport=self.transport, ), method_name="execute_complete", ) @@ -350,7 +359,11 @@ def execute_complete(self, context: Context, event: dict): f"Operation failed with error code [{error_code}] and error message [{error_message}]" ) - hook: CloudRunHook = CloudRunHook(self.gcp_conn_id, self.impersonation_chain) + hook: CloudRunHook = CloudRunHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + transport=self.transport, + ) job = hook.get_job(job_name=event["job_name"], region=self.region, project_id=self.project_id) return Job.to_dict(job) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py index b5547f45bac7e..8261edd416a3e 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py @@ -19,7 +19,7 @@ import asyncio from collections.abc import AsyncIterator, Sequence from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook @@ -59,6 +59,9 @@ class CloudRunJobFinishedTrigger(BaseTrigger): account from the list granting this role to the originating account (templated). :param poll_sleep: Polling period in seconds to check for the status. :timeout: The time to wait before failing the operation. + :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'. + Defaults to 'grpc'. Use 'rest' if gRPC is not available or fails in your environment + (e.g., Docker containers with certain network configurations). """ def __init__( @@ -71,6 +74,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, polling_period_seconds: float = 10, timeout: float | None = None, + transport: Literal["rest", "grpc"] | None = None, ): super().__init__() self.project_id = project_id @@ -81,6 +85,7 @@ def __init__( self.polling_period_seconds = polling_period_seconds self.timeout = timeout self.impersonation_chain = impersonation_chain + self.transport = transport def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize class arguments and classpath.""" @@ -95,6 +100,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "polling_period_seconds": self.polling_period_seconds, "timeout": self.timeout, "impersonation_chain": self.impersonation_chain, + "transport": self.transport, }, ) @@ -143,4 +149,5 @@ def _get_async_hook(self) -> CloudRunAsyncHook: return CloudRunAsyncHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, + transport=self.transport or "grpc", ) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py index bccea8c3e34f7..4a8150459c4a0 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py @@ -259,6 +259,22 @@ def test_delete_job(self, mock_batch_service_client, cloud_run_hook): cloud_run_hook.delete_job(job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID) cloud_run_hook._client.delete_job.assert_called_once_with(delete_request) + @mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__", + new=mock_base_gcp_hook_default_project_id, + ) + @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient") + @pytest.mark.parametrize(("transport", "expected_transport"), [("rest", "rest"), (None, None)]) + def test_get_conn_with_transport(self, mock_jobs_client, transport, expected_transport): + """Test that transport parameter is passed to JobsClient.""" + hook = CloudRunHook(transport=transport) + hook.get_credentials = self.dummy_get_credentials + hook.get_conn() + + mock_jobs_client.assert_called_once() + call_kwargs = mock_jobs_client.call_args[1] + assert call_kwargs["transport"] == expected_transport + def _mock_pager(self, number_of_jobs): mock_pager = [] for i in range(number_of_jobs): diff --git a/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py b/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py index d4431877121e2..3c389713b1ab4 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py +++ b/providers/google/tests/unit/google/cloud/operators/test_cloud_run.py @@ -102,6 +102,28 @@ def test_template_fields(self): assert "overrides" in operator.template_fields assert "polling_period_seconds" in operator.template_fields assert "timeout_seconds" in operator.template_fields + assert "transport" in operator.template_fields + + @mock.patch(CLOUD_RUN_HOOK_PATH) + def test_execute_with_transport(self, hook_mock): + """Test that transport parameter is passed to CloudRunHook.""" + hook_mock.return_value.get_job.return_value = JOB + hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0) + + operator = CloudRunExecuteJobOperator( + task_id=TASK_ID, + project_id=PROJECT_ID, + region=REGION, + job_name=JOB_NAME, + transport="rest", + ) + + operator.execute(context=mock.MagicMock()) + + # Verify that CloudRunHook was instantiated with transport parameter + hook_mock.assert_called_once() + call_kwargs = hook_mock.call_args[1] + assert call_kwargs["transport"] == "rest" @mock.patch(CLOUD_RUN_HOOK_PATH) def test_execute_success(self, hook_mock): diff --git a/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py b/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py index 7a526d590c265..3902a17885e0f 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_run.py @@ -49,6 +49,7 @@ def trigger(): polling_period_seconds=POLL_SLEEP, timeout=TIMEOUT, impersonation_chain=IMPERSONATION_CHAIN, + transport=None, ) @@ -65,6 +66,7 @@ def test_serialization(self, trigger): "polling_period_seconds": POLL_SLEEP, "timeout": TIMEOUT, "impersonation_chain": IMPERSONATION_CHAIN, + "transport": None, } @pytest.mark.asyncio