From 326cf44eeac176f490eaf20feec40edeee3e611f Mon Sep 17 00:00:00 2001 From: Ulada Zakharava Date: Tue, 30 Sep 2025 14:26:39 +0000 Subject: [PATCH] Fix dataflow java system test + link --- .../providers/apache/beam/operators/beam.py | 14 +++- .../providers/apache/beam/triggers/beam.py | 10 ++- .../unit/apache/beam/operators/test_beam.py | 14 +++- .../unit/apache/beam/triggers/test_beam.py | 77 +++++++++++++++---- .../dataflow/example_dataflow_native_java.py | 4 +- 5 files changed, 93 insertions(+), 26 deletions(-) diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py index 8d4d893c9196a..7713bbf067b1c 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py @@ -458,7 +458,12 @@ def execute_on_dataflow(self, context: Context): ) location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION - DataflowJobLink.persist(context=context, region=location) + DataflowJobLink.persist( + context=context, + region=self.dataflow_config.location, + job_id=self.dataflow_job_id, + project_id=self.dataflow_config.project_id, + ) if self.deferrable: trigger_args = { @@ -648,7 +653,12 @@ def execute_on_dataflow(self, context: Context): is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback, ) if self.dataflow_job_name and self.dataflow_config.location: - DataflowJobLink.persist(context=context) + DataflowJobLink.persist( + context=context, + region=self.dataflow_config.location, + job_id=self.dataflow_job_id, + project_id=self.dataflow_config.project_id, + ) if self.deferrable: trigger_args = { "job_id": self.dataflow_job_id, diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py index 5778a12666614..6b2464eb37fb2 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/triggers/beam.py @@ -39,23 +39,25 @@ def file_has_gcs_path(file_path: str): @staticmethod async def provide_gcs_tempfile(gcs_file, gcp_conn_id): try: - from airflow.providers.google.cloud.hooks.gcs import GCSHook + from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook except ImportError: from airflow.exceptions import AirflowOptionalProviderFeatureException raise AirflowOptionalProviderFeatureException( - "Failed to import GCSHook. To use the GCSHook functionality, please install the " + "Failed to import GCSAsyncHook. To use the GCSAsyncHook functionality, please install the " "apache-airflow-google-provider." ) - gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id) + async_gcs_hook = GCSAsyncHook(gcp_conn_id=gcp_conn_id) + sync_gcs_hook = await async_gcs_hook.get_sync_hook() + loop = asyncio.get_running_loop() # Running synchronous `enter_context()` method in a separate # thread using the default executor `None`. The `run_in_executor()` function returns the # file object, which is created using gcs function `provide_file()`, asynchronously. # This means we can perform asynchronous operations with this file. - create_tmp_file_call = gcs_hook.provide_file(object_url=gcs_file) + create_tmp_file_call = sync_gcs_hook.provide_file(object_url=gcs_file) tmp_gcs_file: IO[str] = await loop.run_in_executor( None, contextlib.ExitStack().enter_context, diff --git a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py index 626800de7cac6..0727d472f5ef6 100644 --- a/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py +++ b/providers/apache/beam/tests/unit/apache/beam/operators/test_beam.py @@ -247,7 +247,12 @@ def test_exec_dataflow_runner( } gcs_provide_file.assert_any_call(object_url=PY_FILE) gcs_provide_file.assert_any_call(object_url=REQURIEMENTS_FILE) - persist_link_mock.assert_called_once_with(context={}, region="us-central1") + persist_link_mock.assert_called_once_with( + context={}, + region="us-central1", + job_id=None, + project_id=dataflow_hook_mock.return_value.project_id, + ) beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with( variables=expected_options, py_file=gcs_provide_file.return_value.__enter__.return_value.name, @@ -468,7 +473,12 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock "output": "gs://test/output", "impersonateServiceAccount": TEST_IMPERSONATION_ACCOUNT, } - persist_link_mock.assert_called_once_with(context={}) + persist_link_mock.assert_called_once_with( + context={}, + region="us-central1", + job_id=None, + project_id=dataflow_hook_mock.return_value.project_id, + ) beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with( variables=expected_options, jar=gcs_provide_file.return_value.__enter__.return_value.name, diff --git a/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py b/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py index aed634c223be7..543b1d7fbe590 100644 --- a/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py +++ b/providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import asyncio from unittest import mock import pytest @@ -134,17 +135,41 @@ async def test_beam_trigger_exception_should_execute_successfully( assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual @pytest.mark.asyncio - async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, python_trigger): + async def test_beam_trigger_gcs_provide_file_should_execute_successfully( + self, python_trigger, monkeypatch + ): """ - Test that BeamPythonPipelineTrigger downloads GCS provide file correct. + Test that BeamPythonPipelineTrigger downloads GCS provide file correctly with GCSAsyncHook. """ + TEST_GCS_PY_FILE = "gs://bucket/path/file.py" python_trigger.py_file = TEST_GCS_PY_FILE - with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook: - mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file" - generator = python_trigger.run() - await generator.asend(None) - mock_gcs_hook.assert_called_once_with(gcp_conn_id=python_trigger.gcp_conn_id) - mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE) + + with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSAsyncHook") as MockAsyncHook: + async_hook_instance = MockAsyncHook.return_value + + class DummyCM: + def __enter__(self): + return "mocked_temp_file" + + def __exit__(self, exc_type, exc, tb): + return False + + sync_hook = mock.Mock(name="SyncGCSHook") + sync_hook.provide_file.return_value = DummyCM() + + async_hook_instance.get_sync_hook = mock.AsyncMock(return_value=sync_hook) + + fake_loop = mock.Mock() + fake_loop.run_in_executor = mock.AsyncMock(return_value="mocked_temp_file") + monkeypatch.setattr(asyncio, "get_running_loop", lambda: fake_loop) + + gen = python_trigger.run() + await gen.asend(None) + + MockAsyncHook.assert_called_once_with(gcp_conn_id=python_trigger.gcp_conn_id) + async_hook_instance.get_sync_hook.assert_awaited_once() + sync_hook.provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE) + fake_loop.run_in_executor.assert_awaited_once() class TestBeamJavaPipelineTrigger: @@ -211,15 +236,35 @@ async def test_beam_trigger_exception_should_execute_successfully( assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual @pytest.mark.asyncio - async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, java_trigger): + async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, java_trigger, monkeypatch): """ - Test that BeamJavaPipelineTrigger downloads GCS provide file correct. + Test that BeamJavaPipelineTrigger downloads GCS provide file correctly with GCSAsyncHook. """ java_trigger.jar = TEST_GCS_JAR_FILE - with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook: - mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file" - generator = java_trigger.run() - await generator.asend(None) - mock_gcs_hook.assert_called_once_with(gcp_conn_id=java_trigger.gcp_conn_id) - mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE) + with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSAsyncHook") as MockAsyncHook: + async_hook_instance = MockAsyncHook.return_value + + class DummyCM: + def __enter__(self): + return "mocked_temp_file" + + def __exit__(self, exc_type, exc, tb): + return False + + sync_hook = mock.Mock(name="SyncGCSHook") + sync_hook.provide_file.return_value = DummyCM() + + async_hook_instance.get_sync_hook = mock.AsyncMock(return_value=sync_hook) + + fake_loop = mock.Mock() + fake_loop.run_in_executor = mock.AsyncMock(return_value="mocked_temp_file") + monkeypatch.setattr(asyncio, "get_running_loop", lambda: fake_loop) + + gen = java_trigger.run() + await gen.asend(None) + + MockAsyncHook.assert_called_once_with(gcp_conn_id=java_trigger.gcp_conn_id) + async_hook_instance.get_sync_hook.assert_awaited_once() + sync_hook.provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE) + fake_loop.run_in_executor.assert_awaited_once() diff --git a/providers/google/tests/system/google/cloud/dataflow/example_dataflow_native_java.py b/providers/google/tests/system/google/cloud/dataflow/example_dataflow_native_java.py index 550b0381e6a62..ed74ede79d472 100644 --- a/providers/google/tests/system/google/cloud/dataflow/example_dataflow_native_java.py +++ b/providers/google/tests/system/google/cloud/dataflow/example_dataflow_native_java.py @@ -87,7 +87,7 @@ # [START howto_operator_start_java_job_local_jar] start_java_job_direct = BeamRunJavaPipelineOperator( task_id="start_java_job_direct", - jar=LOCAL_JAR, + jar=GCS_JAR, pipeline_options={ "output": GCS_OUTPUT, }, @@ -102,7 +102,7 @@ start_java_job_direct_deferrable = BeamRunJavaPipelineOperator( task_id="start_java_job_direct_deferrable", - jar=GCS_JAR, + jar=LOCAL_JAR, pipeline_options={ "output": GCS_OUTPUT, },