diff --git a/providers/google/docs/operators/cloud/gen_ai.rst b/providers/google/docs/operators/cloud/gen_ai.rst index 1cbb4534aa1b5..1d115aa763ffd 100644 --- a/providers/google/docs/operators/cloud/gen_ai.rst +++ b/providers/google/docs/operators/cloud/gen_ai.rst @@ -142,6 +142,14 @@ If you use file take a look at this example: :start-after: [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task] :end-before: [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task] +The operator can also run in deferrable mode: + +.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_deferrable_task] + :end-before: [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_deferrable_task] + Get batch job """"""""""""" @@ -220,6 +228,14 @@ If you use file take a look at this example: :start-after: [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task] :end-before: [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task] +The operator can also run in deferrable mode: + +.. 
exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_deferrable_task] + :end-before: [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_deferrable_task] + Interacting with Gemini Files API ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 357ea767591e1..2da3ff99cd123 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -958,6 +958,9 @@ triggers: - integration-name: Google Vertex AI python-modules: - airflow.providers.google.cloud.triggers.vertex_ai + - integration-name: Google Cloud Generative AI + python-modules: + - airflow.providers.google.cloud.triggers.gen_ai transfers: - source-integration-name: Presto diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py b/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py index 827470cbe2e94..c71f314738222 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/gen_ai.py @@ -19,12 +19,17 @@ from __future__ import annotations +import enum import time from typing import TYPE_CHECKING, Any from google import genai -from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import ( + PROVIDE_PROJECT_ID, + GoogleBaseAsyncHook, + GoogleBaseHook, +) if TYPE_CHECKING: from google.genai.pagers import Pager @@ -49,6 +54,17 @@ ) +class BatchJobStatus(enum.Enum): + """Possible states of batch job in Gemini Batch API.""" + + SUCCEEDED = "JOB_STATE_SUCCEEDED" + PENDING = "JOB_STATE_PENDING" + FAILED = "JOB_STATE_FAILED" + RUNNING = "JOB_STATE_RUNNING" + CANCELLED = "JOB_STATE_CANCELLED" + EXPIRED = "JOB_STATE_EXPIRED" + + class 
GenAIGenerativeModelHook(GoogleBaseHook): """Class for Google Cloud Generative AI Vertex AI hook.""" @@ -382,3 +398,79 @@ def delete_file(self, file_name: str) -> DeleteFileResponse: client = self.get_genai_client() resp = client.files.delete(name=file_name) return resp + + +class GenAIGeminiAPIAsyncHook(GoogleBaseAsyncHook): + """Class for Google Cloud Generative AI Gemini Developer Async API hook.""" + + sync_hook_class = GenAIGeminiAPIHook + + def __init__(self, gemini_api_key: str, **kwargs) -> None: + super().__init__(**kwargs) + self.gemini_api_key = gemini_api_key + + async def get_async_client(self): + return genai.Client( + api_key=self.gemini_api_key, + vertexai=False, + ).aio + + async def create_batch_job( + self, + model: str, + source: list | str, + create_batch_job_config: CreateBatchJobConfig | dict | None = None, + ) -> BatchJob: + """ + Create batch job asynchronously using Gemini Batch API to process large-scale, non-urgent tasks. + + :param model: Required. Gemini model name to process requests. + :param source: Required. Requests that will be sent to chosen model. + Can be in format of Inline requests or file name. + :param create_batch_job_config: Optional. Configuration parameters for batch job. + """ + async_client = await self.get_async_client() + async_job = await async_client.batches.create(model=model, src=source, config=create_batch_job_config) + return async_job + + async def get_batch_job( + self, + job_name: str, + ) -> BatchJob: + """ + Get batch job using Gemini Batch API asynchronously. + + :param job_name: Required. Batch job name. 
+ """ + async_client = await self.get_async_client() + resp = await async_client.batches.get(name=job_name) + return resp + + async def create_embeddings_batch_job( + self, + model: str, + source: dict | str, + create_embeddings_config: CreateBatchJobConfig | dict | None = None, + ) -> BatchJob: + """ + Create batch job for embeddings asynchronously using Gemini Batch API to process large-scale, non-urgent tasks. + + :param model: Required. Gemini model name to process requests. + :param source: Required. Requests that will be sent to chosen model. + Can be in format of Inline requests or file name. + :param create_embeddings_config: Optional. Configuration parameters for embeddings batch job. + """ + async_client = await self.get_async_client() + input_type = "inlined_requests" + + if isinstance(source, str): + input_type = "file_name" + + self.log.info("Using %s to create embeddings", input_type) + + resp = await async_client.batches.create_embeddings( + model=model, + src={input_type: source}, + config=create_embeddings_config, + ) + return resp diff --git a/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py b/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py index 94e3688ff0e26..faec87f498f6f 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py @@ -19,20 +19,26 @@ from __future__ import annotations -import enum import os.path import time from collections.abc import Sequence from typing import TYPE_CHECKING, Any from google.genai.errors import ClientError +from google.genai.types import BatchJob +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.gen_ai import ( + BatchJobStatus, GenAIGeminiAPIHook, GenAIGenerativeModelHook, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from 
airflow.providers.google.cloud.triggers.gen_ai import ( + GenAIGeminiCreateBatchJobTrigger, + GenAIGeminiCreateEmbeddingsBatchJobTrigger, +) if TYPE_CHECKING: from google.genai.types import ( @@ -398,17 +404,6 @@ def execute(self, context: Context): return cached_content_name -class BatchJobStatus(enum.Enum): - """Possible states of batch job in Gemini Batch API.""" - - SUCCEEDED = "JOB_STATE_SUCCEEDED" - PENDING = "JOB_STATE_PENDING" - FAILED = "JOB_STATE_FAILED" - RUNNING = "JOB_STATE_RUNNING" - CANCELLED = "JOB_STATE_CANCELLED" - EXPIRED = "JOB_STATE_EXPIRED" - - class GenAIGeminiCreateBatchJobOperator(GoogleCloudBaseOperator): """ Create Batch job using Gemini Batch API. Use to generate model response for several requests. @@ -433,6 +428,7 @@ class GenAIGeminiCreateBatchJobOperator(GoogleCloudBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: Optional. Run operator in the deferrable mode. 
""" template_fields = ( @@ -460,6 +456,7 @@ def __init__( retrieve_result: bool = False, wait_until_complete: bool = False, polling_interval: int = 30, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -476,9 +473,12 @@ def __init__( self.wait_until_complete = wait_until_complete self.polling_interval = polling_interval self.results_folder = results_folder + self.deferrable = deferrable - if self.retrieve_result and not self.wait_until_complete: - raise AirflowException("Retrieving results is possible only if wait_until_complete set to True") + if self.retrieve_result and not (self.wait_until_complete or self.deferrable): + raise AirflowException( + "Retrieving results is possible only if wait_until_complete set to True or in deferrable mode" + ) if self.results_folder and not isinstance(self.input_source, str): raise AirflowException("results_folder works only when input_source is file name") if self.results_folder and not os.path.exists(os.path.abspath(self.results_folder)): @@ -533,13 +533,33 @@ def _prepare_results_for_xcom(self, job): return results - def execute(self, context: Context): - self.hook = GenAIGeminiAPIHook( + @property + def hook(self): + return GenAIGeminiAPIHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, gemini_api_key=self.gemini_api_key, ) + def execute(self, context: Context): + if self.deferrable: + self.defer( + trigger=GenAIGeminiCreateBatchJobTrigger( + project_id=self.project_id, + location=self.location, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + model=self.model, + input_source=self.input_source, + create_batch_job_config=self.create_batch_job_config, + gemini_api_key=self.gemini_api_key, + retrieve_result=self.retrieve_result, + polling_interval=self.polling_interval, + results_folder=self.results_folder, + ), + method_name="execute_complete", + ) + try: job = 
self.hook.create_batch_job( model=self.model, @@ -560,6 +580,16 @@ def execute(self, context: Context): return dict(job) + def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]: + if event["status"] == "error": + self.log.info("status: %s, msg: %s", event["status"], event["message"]) + raise AirflowException(event["message"]) + job = self.hook.get_batch_job(event["job_name"]) + if self.retrieve_result and job.error is None: + job_results = self._prepare_results_for_xcom(job) + context["ti"].xcom_push(key="job_results", value=job_results) + return dict(job) + class GenAIGeminiGetBatchJobOperator(GoogleCloudBaseOperator): """ @@ -841,6 +871,7 @@ class GenAIGeminiCreateEmbeddingsBatchJobOperator(GoogleCloudBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: Optional. Run operator in the deferrable mode. 
""" template_fields = ( @@ -868,6 +899,7 @@ def __init__( wait_until_complete: bool = False, retrieve_result: bool = False, polling_interval: int = 30, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -884,9 +916,12 @@ def __init__( self.retrieve_result = retrieve_result self.polling_interval = polling_interval self.results_folder = results_folder + self.deferrable = deferrable - if self.retrieve_result and not self.wait_until_complete: - raise AirflowException("Retrieving results is possible only if wait_until_complete set to True") + if self.retrieve_result and not (self.wait_until_complete or self.deferrable): + raise AirflowException( + "Retrieving results is possible only if wait_until_complete set to True or in deferrable mode" + ) if self.results_folder and not isinstance(self.input_source, str): raise AirflowException("results_folder works only when input_source is file name") if self.results_folder and not os.path.exists(os.path.abspath(self.results_folder)): @@ -941,13 +976,33 @@ def _prepare_results_for_xcom(self, job): return results - def execute(self, context: Context): - self.hook = GenAIGeminiAPIHook( + @property + def hook(self): + return GenAIGeminiAPIHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, gemini_api_key=self.gemini_api_key, ) + def execute(self, context: Context): + if self.deferrable: + self.defer( + trigger=GenAIGeminiCreateEmbeddingsBatchJobTrigger( + project_id=self.project_id, + location=self.location, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + model=self.model, + input_source=self.input_source, + create_embeddings_config=self.create_embeddings_config, + gemini_api_key=self.gemini_api_key, + retrieve_result=self.retrieve_result, + polling_interval=self.polling_interval, + results_folder=self.results_folder, + ), + method_name="execute_complete", + ) + try: embeddings_job 
= self.hook.create_embeddings( model=self.model, @@ -968,10 +1023,19 @@ def execute(self, context: Context): return embeddings_job.model_dump() + def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]: + if event["status"] == "error": + self.log.info("status: %s, msg: %s", event["status"], event["message"]) + raise AirflowException(event["message"]) + if self.retrieve_result and event["job"].get("error") is None: + job_results = self._prepare_results_for_xcom(BatchJob(**event["job"])) + context["ti"].xcom_push(key="job_results", value=job_results) + return event["job"] + class GenAIGeminiUploadFileOperator(GoogleCloudBaseOperator): """ - Get file uploaded to Gemini Files API. + Upload file to Gemini Files API. The Files API lets you store up to 20GB of files per project, with each file not exceeding 2GB in size. Supported types are audio files, images, videos, documents, and others. Files are stored for 48 hours. diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/gen_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/gen_ai.py new file mode 100644 index 0000000000000..8c48b42d43446 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/triggers/gen_ai.py @@ -0,0 +1,284 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from airflow.providers.google.cloud.hooks.gen_ai import BatchJobStatus, GenAIGeminiAPIAsyncHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + +if TYPE_CHECKING: + from google.genai.types import CreateBatchJobConfig + + +class GenAIGeminiCreateBatchJobTrigger(BaseTrigger): + """ + Trigger that creates Gemini Batch Job and waiting for execution. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param model: Required. The name of the publisher model to use for Batch job. + :param gemini_api_key: Required. Key to interact with Gemini Batch API. + :param input_source: Required. Source of requests, could be inline requests or file name. + :param results_folder: Optional. Path to a folder on local machine where file with results will be saved. + :param create_batch_job_config: Optional. Config for batch job creation. + :param retrieve_result: Optional. Push the result to XCom. If the input_source is inline, this pushes + the execution result. If a file name is specified, this pushes the output file path. + :param polling_interval: Optional. The interval, in seconds, to poll the job status. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + def __init__( + self, + project_id: str, + location: str, + model: str, + input_source: list | str, + gemini_api_key: str, + create_batch_job_config: CreateBatchJobConfig | dict | None = None, + results_folder: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + retrieve_result: bool = False, + polling_interval: int = 30, + ): + super().__init__() + + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.model = model + self.input_source = input_source + self.create_batch_job_config = create_batch_job_config + self.gemini_api_key = gemini_api_key + self.retrieve_result = retrieve_result + self.polling_interval = polling_interval + self.results_folder = results_folder + + def _get_async_hook(self) -> GenAIGeminiAPIAsyncHook: + return GenAIGeminiAPIAsyncHook(gemini_api_key=self.gemini_api_key) + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.gen_ai.GenAIGeminiCreateBatchJobTrigger", + { + "project_id": self.project_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "model": self.model, + "input_source": self.input_source, + "create_batch_job_config": self.create_batch_job_config, + "gemini_api_key": self.gemini_api_key, + "retrieve_result": self.retrieve_result, + "polling_interval": self.polling_interval, + "results_folder": self.results_folder, + }, + ) + + async def run(self): + """ + Loop until the job reaches successful final or error state. 
+ + Yields a TriggerEvent with success status, if the job reaches successful state. + + Yields a TriggerEvent with error status, if the client returns an unexpected terminal + job status or any exception is raised while looping. + + In any other case the Trigger will wait for a specified amount of time + stored in self.polling_interval variable. + """ + try: + hook = self._get_async_hook() + job = await hook.create_batch_job( + model=self.model, + source=self.input_source, + create_batch_job_config=self.create_batch_job_config, + ) + while True: + job = await hook.get_batch_job(job_name=job.name) + if job.state.name == BatchJobStatus.SUCCEEDED.value: + self.log.info("Job execution completed") + yield TriggerEvent( + { + "status": "success", + "message": "Job completed", + "job_name": job.name, + } + ) + return + if job.state.name in [ + BatchJobStatus.FAILED.value, + BatchJobStatus.EXPIRED.value, + BatchJobStatus.CANCELLED.value, + ]: + self.log.error("Job execution was not completed!") + yield TriggerEvent( + {"status": "error", "message": f"Job {job.name} execution was not completed!"} + ) + return + self.log.info( + "Waiting for job execution, polling interval: %s seconds, current state: %s", + self.polling_interval, + job.state.name, + ) + await asyncio.sleep(self.polling_interval) + + except Exception as e: + self.log.exception("Exception occurred while checking for job completion.") + yield TriggerEvent({"status": "error", "message": str(e)}) + + +class GenAIGeminiCreateEmbeddingsBatchJobTrigger(BaseTrigger): + """ + Trigger that creates Gemini Embeddings Batch Job and waiting for execution. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param model: Required. The name of the publisher model to use for Batch job. + :param gemini_api_key: Required. Key to interact with Gemini Batch API. 
+ :param input_source: Required. Source of requests, could be inline requests or file name. + :param results_folder: Optional. Path to a folder on local machine where file with results will be saved. + :param create_embeddings_config: Optional. Config for batch job creation. + :param retrieve_result: Optional. Push the result to XCom. If the input_source is inline, this pushes + the execution result. If a file name is specified, this pushes the output file path. + :param polling_interval: Optional. The interval, in seconds, to poll the job status. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + def __init__( + self, + project_id: str, + location: str, + model: str, + input_source: dict | str, + gemini_api_key: str, + create_embeddings_config: CreateBatchJobConfig | dict | None = None, + results_folder: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + retrieve_result: bool = False, + polling_interval: int = 30, + ): + super().__init__() + + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.model = model + self.input_source = input_source + self.create_embeddings_config = create_embeddings_config + self.gemini_api_key = gemini_api_key + self.retrieve_result = retrieve_result + self.polling_interval = polling_interval + self.results_folder = results_folder + + def _get_async_hook(self) -> GenAIGeminiAPIAsyncHook: + return GenAIGeminiAPIAsyncHook(gemini_api_key=self.gemini_api_key) + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.gen_ai.GenAIGeminiCreateEmbeddingsBatchJobTrigger", + { + "project_id": self.project_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "model": self.model, + "input_source": self.input_source, + "create_embeddings_config": self.create_embeddings_config, + "gemini_api_key": self.gemini_api_key, + "retrieve_result": self.retrieve_result, + "polling_interval": self.polling_interval, + "results_folder": self.results_folder, + }, + ) + + async def run(self): + """ + Loop until the job reaches successful final or error state. + + Yields a TriggerEvent with success status, if the job reaches successful state. + + Yields a TriggerEvent with error status, if the client returns an unexpected terminal + job status or any exception is raised while looping. 
+ + In any other case the Trigger will wait for a specified amount of time + stored in self.polling_interval variable. + """ + try: + hook = self._get_async_hook() + job = await hook.create_embeddings_batch_job( + model=self.model, + source=self.input_source, + create_embeddings_config=self.create_embeddings_config, + ) + while True: + job = await hook.get_batch_job(job_name=job.name) + if job.state.name == BatchJobStatus.SUCCEEDED.value: + self.log.info("Job execution completed") + yield TriggerEvent( + { + "status": "success", + "message": "Job completed", + "job": job.model_dump(), + } + ) + return + if job.state.name in [ + BatchJobStatus.FAILED.value, + BatchJobStatus.EXPIRED.value, + BatchJobStatus.CANCELLED.value, + ]: + self.log.error("Job execution was not completed!") + yield TriggerEvent( + {"status": "error", "message": f"Job {job.name} execution was not completed!"} + ) + return + self.log.info( + "Waiting for job execution, polling interval: %s seconds, current state: %s", + self.polling_interval, + job.state.name, + ) + await asyncio.sleep(self.polling_interval) + + except Exception as e: + self.log.exception("Exception occurred while checking for job completion.") + yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index 12d473d44a2e0..1defae1ed4fe7 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -1125,6 +1125,10 @@ def get_provider_info(): "integration-name": "Google Vertex AI", "python-modules": ["airflow.providers.google.cloud.triggers.vertex_ai"], }, + { + "integration-name": "Google Cloud Generative AI", + "python-modules": ["airflow.providers.google.cloud.triggers.gen_ai"], + }, ], "transfers": [ { diff --git 
a/providers/google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py b/providers/google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py index d942466c17f9e..3b14b328b335c 100644 --- a/providers/google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py +++ b/providers/google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py @@ -188,6 +188,22 @@ def get_gemini_api_key(): ) # [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_task] + # [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_deferrable_task] + create_batch_job_using_inlined_requests_deferrable = GenAIGeminiCreateBatchJobOperator( + task_id="create_batch_job_using_inlined_requests_deferrable_task", + project_id=PROJECT_ID, + location=REGION, + model="gemini-3-pro-preview", + gemini_api_key=GEMINI_XCOM_API_KEY, + create_batch_job_config={ + "display_name": "deferrable-inlined-requests-batch-job", + }, + input_source=INLINED_REQUESTS_FOR_BATCH_JOB, + retrieve_result=True, + deferrable=True, + ) + # [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_deferrable_task] + # [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task] create_batch_job_using_file = GenAIGeminiCreateBatchJobOperator( task_id="create_batch_job_using_file_task", @@ -235,6 +251,23 @@ def get_gemini_api_key(): ) # [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task] + # [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_deferrable_task] + create_embeddings_job_using_file_deferrable = GenAIGeminiCreateEmbeddingsBatchJobOperator( + task_id="create_embeddings_job_using_file_deferrable_task", + project_id=PROJECT_ID, + location=REGION, + model="gemini-embedding-001", + retrieve_result=True, + deferrable=True, + gemini_api_key=GEMINI_XCOM_API_KEY, + create_embeddings_config={ + "display_name": "deferrable-file-upload-embeddings-job", + }, + 
input_source=UPLOADED_EMBEDDINGS_FILE_NAME, + results_folder=PATH_TO_SAVE_RESULTS, + ) + # [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_deferrable_task] + # [START how_to_cloud_gen_ai_batch_api_get_batch_job_task] get_batch_job = GenAIGeminiGetBatchJobOperator( task_id="get_batch_job_task", @@ -318,6 +351,8 @@ def get_gemini_api_key(): create_batch_job_using_file, create_embeddings_job_using_file, create_embeddings_job_using_inlined_requests, + create_batch_job_using_inlined_requests_deferrable, + create_embeddings_job_using_file_deferrable, ] >> get_batch_job >> list_batch_jobs diff --git a/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py b/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py index 095c54664618f..e0ac0fb327b1c 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_gen_ai.py @@ -18,7 +18,9 @@ from __future__ import annotations from unittest import mock +from unittest.mock import AsyncMock +import pytest from google.genai.types import ( Content, CreateCachedContentConfig, @@ -29,6 +31,7 @@ ) from airflow.providers.google.cloud.hooks.gen_ai import ( + GenAIGeminiAPIAsyncHook, GenAIGeminiAPIHook, GenAIGenerativeModelHook, ) @@ -375,3 +378,66 @@ def test_delete_file(self, mock_get_client): self.hook.delete_file(file_name=TEST_FILE_NAME) client_mock.files.delete.assert_called_once_with(name=TEST_FILE_NAME) + + +def mock_init(*args, **kwargs): + pass + + +class TestGenAIGeminiAPIAsyncHook: + def setup_method(self, method): + with mock.patch(BASE_STRING.format("GoogleBaseAsyncHook.__init__"), new=mock_init): + self.hook = GenAIGeminiAPIAsyncHook(gemini_api_key=TEST_API_KEY) + + @pytest.mark.asyncio + @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIAsyncHook.get_async_client")) + async def test_get_job(self, mock_client): + mock_async_client = AsyncMock() + mock_client.return_value = mock_async_client + + await 
self.hook.get_batch_job( + job_name=TEST_JOB_NAME, + ) + + mock_client.assert_called_once() + mock_async_client.batches.get.assert_called_once_with( + name=TEST_JOB_NAME, + ) + + @pytest.mark.asyncio + @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIAsyncHook.get_async_client")) + async def test_create_batch_job(self, mock_client): + mock_async_client = AsyncMock() + mock_client.return_value = mock_async_client + + await self.hook.create_batch_job( + model=TEST_MODEL, + source=TEST_BATCH_JOB_SOURCE_INLINE, + create_batch_job_config=TEST_CREATE_BATCH_JOB_CONFIG, + ) + + mock_client.assert_called_once() + mock_async_client.batches.create.assert_called_once_with( + model=TEST_MODEL, + src=TEST_BATCH_JOB_SOURCE_INLINE, + config=TEST_CREATE_BATCH_JOB_CONFIG, + ) + + @pytest.mark.asyncio + @mock.patch(GENERATIVE_MODEL_STRING.format("GenAIGeminiAPIAsyncHook.get_async_client")) + async def test_create_embeddings_job(self, mock_client): + mock_async_client = AsyncMock() + mock_client.return_value = mock_async_client + + await self.hook.create_embeddings_batch_job( + model=TEST_MODEL, + source=TEST_EMBEDDINGS_JOB_SOURCE_INLINE, + create_embeddings_config=TEST_CREATE_BATCH_JOB_CONFIG, + ) + + mock_client.assert_called_once() + mock_async_client.batches.create_embeddings.assert_called_once_with( + model=TEST_MODEL, + src={"inlined_requests": TEST_EMBEDDINGS_JOB_SOURCE_INLINE}, + config=TEST_CREATE_BATCH_JOB_CONFIG, + ) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_gen_ai.py b/providers/google/tests/unit/google/cloud/triggers/test_gen_ai.py new file mode 100644 index 0000000000000..6082b899bcd44 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/triggers/test_gen_ai.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import asyncio +import logging +from unittest import mock + +import pytest + +from airflow.models import Connection +from airflow.providers.google.cloud.hooks.gen_ai import BatchJobStatus +from airflow.providers.google.cloud.triggers.gen_ai import ( + GenAIGeminiCreateBatchJobTrigger, + GenAIGeminiCreateEmbeddingsBatchJobTrigger, +) +from airflow.triggers.base import TriggerEvent + +BATCH_JOB_INLINED_REQUESTS = [ + {"contents": [{"parts": [{"text": "Tell me a one-sentence joke."}], "role": "user"}]}, + {"contents": [{"parts": [{"text": "Why is the sky blue?"}], "role": "user"}]}, +] + +EMBEDDINGS_JOB_INLINED_REQUESTS = { + "contents": [{"parts": [{"text": "Why is the sky blue?"}], "role": "user"}] +} +GEMINI_API_KEY = "test-key" +GEMINI_MODEL = "test-gemini-model" +BATCH_JOB_NAME = "test-name" +FILE_NAME = "test-file" +FILE_PATH = "test/path/to/file" +TASK_ID = "test_task_id" +GCP_PROJECT = "test-project" +GCP_LOCATION = "test-location" +GCP_CONN_ID = "test-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +BATCH_JOB_CONFIG = {"display_name": "test-batch-job"} +RESULTS_FOLDER = "test/results/folder" + + +@pytest.fixture +@mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", + return_value=Connection(conn_id=GCP_CONN_ID), +) +def create_batch_job_trigger(mock_conn): + return 
GenAIGeminiCreateBatchJobTrigger( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + model=GEMINI_MODEL, + input_source=BATCH_JOB_INLINED_REQUESTS, + create_batch_job_config=BATCH_JOB_CONFIG, + gemini_api_key=GEMINI_API_KEY, + retrieve_result=True, + polling_interval=30, + results_folder=RESULTS_FOLDER, + ) + + +@pytest.fixture +@mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", + return_value=Connection(conn_id=GCP_CONN_ID), +) +def create_embeddings_batch_job_trigger(mock_conn): + return GenAIGeminiCreateEmbeddingsBatchJobTrigger( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + model=GEMINI_MODEL, + input_source=BATCH_JOB_INLINED_REQUESTS, + create_embeddings_config=BATCH_JOB_CONFIG, + gemini_api_key=GEMINI_API_KEY, + retrieve_result=True, + polling_interval=30, + results_folder=RESULTS_FOLDER, + ) + + +class TestGenAIGeminiCreateBatchJobTrigger: + def test_serialize(self, create_batch_job_trigger): + actual_data = create_batch_job_trigger.serialize() + expected_data = ( + "airflow.providers.google.cloud.triggers.gen_ai.GenAIGeminiCreateBatchJobTrigger", + { + "project_id": GCP_PROJECT, + "location": GCP_LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "impersonation_chain": IMPERSONATION_CHAIN, + "model": GEMINI_MODEL, + "input_source": BATCH_JOB_INLINED_REQUESTS, + "create_batch_job_config": BATCH_JOB_CONFIG, + "gemini_api_key": GEMINI_API_KEY, + "retrieve_result": True, + "polling_interval": 30, + "results_folder": RESULTS_FOLDER, + }, + ) + assert actual_data == expected_data + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.create_batch_job") + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.get_batch_job") + async def test_run_loop_return_success_event( + self, mock_job_status, 
mock_create_batch_job, create_batch_job_trigger + ): + mock_job_status.return_value.state.name = BatchJobStatus.SUCCEEDED.value + mock_job_status.return_value.name = BATCH_JOB_CONFIG["display_name"] + + expected_event = TriggerEvent( + { + "status": "success", + "message": "Job completed", + "job_name": BATCH_JOB_CONFIG["display_name"], + } + ) + actual_event = await create_batch_job_trigger.run().asend(None) + + mock_create_batch_job.assert_called_once() + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.create_batch_job") + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.get_batch_job") + async def test_run_loop_return_failed_event( + self, mock_job_status, mock_create_batch_job, create_batch_job_trigger + ): + mock_job_status.return_value.state.name = BatchJobStatus.FAILED.value + mock_job_status.return_value.name = BATCH_JOB_NAME + + expected_event = TriggerEvent( + {"status": "error", "message": f"Job {BATCH_JOB_NAME} execution was not completed!"} + ) + actual_event = await create_batch_job_trigger.run().asend(None) + + mock_create_batch_job.assert_called_once() + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.create_batch_job") + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.get_batch_job") + async def test_run_loop_is_still_running( + self, mock_job_status, mock_create_batch_job, create_batch_job_trigger, caplog + ): + mock_job_status.return_value.state.name = BatchJobStatus.RUNNING.value + caplog.set_level(logging.INFO) + + task = asyncio.create_task(create_batch_job_trigger.run().__anext__()) + await asyncio.sleep(5) + + assert not task.done() + assert ( + f"Waiting for job execution, polling interval: 30 seconds, current state: {BatchJobStatus.RUNNING.value}" + in caplog.text + ) + # cancel the 
task to suppress test warnings + task.cancel() + + +class TestGenAIGeminiCreateEmbeddingsBatchJobTrigger: + def test_serialize(self, create_embeddings_batch_job_trigger): + actual_data = create_embeddings_batch_job_trigger.serialize() + expected_data = ( + "airflow.providers.google.cloud.triggers.gen_ai.GenAIGeminiCreateEmbeddingsBatchJobTrigger", + { + "project_id": GCP_PROJECT, + "location": GCP_LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "impersonation_chain": IMPERSONATION_CHAIN, + "model": GEMINI_MODEL, + "input_source": BATCH_JOB_INLINED_REQUESTS, + "create_embeddings_config": BATCH_JOB_CONFIG, + "gemini_api_key": GEMINI_API_KEY, + "retrieve_result": True, + "polling_interval": 30, + "results_folder": RESULTS_FOLDER, + }, + ) + assert actual_data == expected_data + + @pytest.mark.asyncio + @mock.patch( + "airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.create_embeddings_batch_job" + ) + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.get_batch_job") + async def test_run_loop_return_success_event( + self, mock_job_status, mock_create_embeddings_batch_job, create_embeddings_batch_job_trigger + ): + test_job_model_dump = {"id": "test_job_id", "status": "succeeded"} + mock_job_status.return_value.state.name = BatchJobStatus.SUCCEEDED.value + mock_job_status.return_value.model_dump = mock.Mock(return_value=test_job_model_dump) + + expected_event = TriggerEvent( + { + "status": "success", + "message": "Job completed", + "job": test_job_model_dump, + } + ) + actual_event = await create_embeddings_batch_job_trigger.run().asend(None) + mock_create_embeddings_batch_job.assert_called_once() + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch( + "airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.create_embeddings_batch_job" + ) + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.get_batch_job") + async def test_run_loop_return_failed_event( + 
self, mock_job_status, mock_create_embeddings_batch_job, create_embeddings_batch_job_trigger + ): + mock_job_status.return_value.state.name = BatchJobStatus.FAILED.value + mock_job_status.return_value.name = BATCH_JOB_NAME + + expected_event = TriggerEvent( + {"status": "error", "message": f"Job {BATCH_JOB_NAME} execution was not completed!"} + ) + actual_event = await create_embeddings_batch_job_trigger.run().asend(None) + + mock_create_embeddings_batch_job.assert_called_once() + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch( + "airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.create_embeddings_batch_job" + ) + @mock.patch("airflow.providers.google.cloud.hooks.gen_ai.GenAIGeminiAPIAsyncHook.get_batch_job") + async def test_run_loop_is_still_running( + self, mock_job_status, mock_create_embeddings_batch_job, create_embeddings_batch_job_trigger, caplog + ): + mock_job_status.return_value.state.name = BatchJobStatus.RUNNING.value + caplog.set_level(logging.INFO) + + task = asyncio.create_task(create_embeddings_batch_job_trigger.run().__anext__()) + await asyncio.sleep(5) + + assert not task.done() + assert ( + f"Waiting for job execution, polling interval: 30 seconds, current state: {BatchJobStatus.RUNNING.value}" + in caplog.text + ) + # cancel the task to suppress test warnings + task.cancel()