diff --git a/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py b/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py index 164ec485710a0..79381044a2f56 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py @@ -646,7 +646,7 @@ def execute(self, context: Context): raise AirflowException("Job with name %s not found", self.job_name) context["ti"].xcom_push(key="job_status", value=job.state) - return dict(job) + return job.model_dump(mode="json") class GenAIGeminiListBatchJobsOperator(GoogleCloudBaseOperator): diff --git a/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py b/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py index f91d154d9f177..c532f8b3e5c14 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py +++ b/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py @@ -107,7 +107,7 @@ TEST_BATCH_JOB_NAME = "test-name" TEST_FILE_NAME = "test-file" TEST_FILE_PATH = "test/path/to/file" -TEST_CREATE_BATCH_JOB_RESPONSE = { +TEST_BATCH_JOB_RESPONSE = { "src": None, "dest": "test-batch-job-destination", "name": "test-name", @@ -317,7 +317,7 @@ def test_execute(self, mock_hook): @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook")) def test_execute_return_value(self, mock_hook): - expected_return = TEST_CREATE_BATCH_JOB_RESPONSE + expected_return = TEST_BATCH_JOB_RESPONSE mock_job = mock.MagicMock() mock_job.model_dump.return_value = expected_return @@ -343,7 +343,7 @@ def test_execute_return_value(self, mock_hook): @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook")) def test_execute_complete_return_value(self, mock_hook): - expected_return = TEST_CREATE_BATCH_JOB_RESPONSE + expected_return = TEST_BATCH_JOB_RESPONSE event = {"status": "success", "job_name": "test-name"} @@ -391,6 +391,29 @@ def test_execute(self, mock_hook): job_name=TEST_BATCH_JOB_NAME, ) + @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook")) + def test_execute_return_value(self, mock_hook): + expected_return = TEST_BATCH_JOB_RESPONSE + + mock_job = mock.MagicMock() + mock_job.model_dump.return_value = expected_return + mock_hook.return_value.get_batch_job.return_value = mock_job + + op = GenAIGeminiGetBatchJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + job_name=TEST_BATCH_JOB_NAME, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + gemini_api_key=TEST_GEMINI_API_KEY, + ) + + result = op.execute(context={"ti": mock.MagicMock()}) + + assert result == expected_return + mock_job.model_dump.assert_called_once_with(mode="json") + class TestGenAIGeminiListBatchJobsOperator: @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))