Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions providers/google/docs/operators/cloud/gen_ai.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ If you use file take a look at this example:
:start-after: [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task]
:end-before: [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_file_task]

The operator can also be run in deferrable mode:

.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_deferrable_task]
:end-before: [END how_to_cloud_gen_ai_batch_api_create_batch_job_with_inlined_requests_deferrable_task]

Get batch job
"""""""""""""

Expand Down Expand Up @@ -220,6 +228,14 @@ If you use file take a look at this example:
:start-after: [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task]
:end-before: [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_task]

The operator can also be run in deferrable mode:

.. exampleinclude:: /../../google/tests/system/google/cloud/gen_ai/example_gen_ai_gemini_batch_api.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_deferrable_task]
:end-before: [END how_to_cloud_gen_ai_batch_api_create_embeddings_with_file_deferrable_task]

Interacting with Gemini Files API
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
3 changes: 3 additions & 0 deletions providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,9 @@ triggers:
- integration-name: Google Vertex AI
python-modules:
- airflow.providers.google.cloud.triggers.vertex_ai
- integration-name: Google Cloud Generative AI
python-modules:
- airflow.providers.google.cloud.triggers.gen_ai

transfers:
- source-integration-name: Presto
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,17 @@

from __future__ import annotations

import enum
import time
from typing import TYPE_CHECKING, Any

from google import genai

from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)

if TYPE_CHECKING:
from google.genai.pagers import Pager
Expand All @@ -49,6 +54,17 @@
)


class BatchJobStatus(enum.Enum):
    """
    Batch job states used with the Gemini Batch API.

    Each member's value is the raw ``JOB_STATE_*`` string for that state, so a
    member can be looked up from such a string with ``BatchJobStatus(value)``.
    """

    SUCCEEDED = "JOB_STATE_SUCCEEDED"
    PENDING = "JOB_STATE_PENDING"
    FAILED = "JOB_STATE_FAILED"
    RUNNING = "JOB_STATE_RUNNING"
    CANCELLED = "JOB_STATE_CANCELLED"
    EXPIRED = "JOB_STATE_EXPIRED"


class GenAIGenerativeModelHook(GoogleBaseHook):
"""Class for Google Cloud Generative AI Vertex AI hook."""

Expand Down Expand Up @@ -382,3 +398,79 @@ def delete_file(self, file_name: str) -> DeleteFileResponse:
client = self.get_genai_client()
resp = client.files.delete(name=file_name)
return resp


class GenAIGeminiAPIAsyncHook(GoogleBaseAsyncHook):
    """
    Async hook for the Google Cloud Generative AI Gemini Developer API.

    Every API method obtains its client through :meth:`get_async_client` and
    awaits the corresponding ``batches`` call on it.

    :param gemini_api_key: API key passed to ``genai.Client``.
    """

    sync_hook_class = GenAIGeminiAPIHook

    def __init__(self, gemini_api_key: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.gemini_api_key = gemini_api_key

    async def get_async_client(self):
        """
        Return the ``aio`` interface of a newly built ``genai.Client``.

        The client is created with the stored API key and ``vertexai=False``;
        a new client is constructed on every call.
        """
        client = genai.Client(api_key=self.gemini_api_key, vertexai=False)
        return client.aio

    async def create_batch_job(
        self,
        model: str,
        source: list | str,
        create_batch_job_config: CreateBatchJobConfig | dict | None = None,
    ) -> BatchJob:
        """
        Asynchronously create a batch job through the Gemini Batch API.

        :param model: Required. Gemini model name to process requests.
        :param source: Required. Requests that will be sent to chosen model.
            Can be in format of Inline requests or file name.
        :param create_batch_job_config: Optional. Configuration parameters for batch job.
        :return: The object returned by ``batches.create``.
        """
        aio_client = await self.get_async_client()
        return await aio_client.batches.create(
            model=model,
            src=source,
            config=create_batch_job_config,
        )

    async def get_batch_job(
        self,
        job_name: str,
    ) -> BatchJob:
        """
        Asynchronously fetch a batch job through the Gemini Batch API.

        :param job_name: Required. Batch job name.
        :return: The object returned by ``batches.get``.
        """
        aio_client = await self.get_async_client()
        return await aio_client.batches.get(name=job_name)

    async def create_embeddings_batch_job(
        self,
        model: str,
        source: dict | str,
        create_embeddings_config: CreateBatchJobConfig | dict | None = None,
    ) -> BatchJob:
        """
        Asynchronously create an embeddings batch job through the Gemini Batch API.

        :param model: Required. Gemini model name to process requests.
        :param source: Required. Requests that will be sent to chosen model.
            Can be in format of Inline requests or file name.
        :param create_embeddings_config: Optional. Configuration parameters for embeddings batch job.
        :return: The object returned by ``batches.create_embeddings``.
        """
        aio_client = await self.get_async_client()

        # A string source is passed under the "file_name" key; any other
        # source is passed under "inlined_requests".
        input_type = "file_name" if isinstance(source, str) else "inlined_requests"
        self.log.info("Using %s to create embeddings", input_type)

        return await aio_client.batches.create_embeddings(
            model=model,
            src={input_type: source},
            config=create_embeddings_config,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,26 @@

from __future__ import annotations

import enum
import os.path
import time
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from google.genai.errors import ClientError
from google.genai.types import BatchJob

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gen_ai import (
BatchJobStatus,
GenAIGeminiAPIHook,
GenAIGenerativeModelHook,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.gen_ai import (
GenAIGeminiCreateBatchJobTrigger,
GenAIGeminiCreateEmbeddingsBatchJobTrigger,
)

if TYPE_CHECKING:
from google.genai.types import (
Expand Down Expand Up @@ -398,17 +404,6 @@ def execute(self, context: Context):
return cached_content_name


class BatchJobStatus(enum.Enum):
"""Possible states of batch job in Gemini Batch API."""

SUCCEEDED = "JOB_STATE_SUCCEEDED"
PENDING = "JOB_STATE_PENDING"
FAILED = "JOB_STATE_FAILED"
RUNNING = "JOB_STATE_RUNNING"
CANCELLED = "JOB_STATE_CANCELLED"
EXPIRED = "JOB_STATE_EXPIRED"


class GenAIGeminiCreateBatchJobOperator(GoogleCloudBaseOperator):
"""
Create Batch job using Gemini Batch API. Use to generate model response for several requests.
Expand All @@ -433,6 +428,7 @@ class GenAIGeminiCreateBatchJobOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Optional. Run operator in the deferrable mode.
"""

template_fields = (
Expand Down Expand Up @@ -460,6 +456,7 @@ def __init__(
retrieve_result: bool = False,
wait_until_complete: bool = False,
polling_interval: int = 30,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -476,9 +473,12 @@ def __init__(
self.wait_until_complete = wait_until_complete
self.polling_interval = polling_interval
self.results_folder = results_folder
self.deferrable = deferrable

if self.retrieve_result and not self.wait_until_complete:
raise AirflowException("Retrieving results is possible only if wait_until_complete set to True")
if self.retrieve_result and not (self.wait_until_complete or self.deferrable):
raise AirflowException(
"Retrieving results is possible only if wait_until_complete set to True or in deferrable mode"
)
if self.results_folder and not isinstance(self.input_source, str):
raise AirflowException("results_folder works only when input_source is file name")
if self.results_folder and not os.path.exists(os.path.abspath(self.results_folder)):
Expand Down Expand Up @@ -533,13 +533,33 @@ def _prepare_results_for_xcom(self, job):

return results

def execute(self, context: Context):
self.hook = GenAIGeminiAPIHook(
@property
def hook(self):
    """
    Build a ``GenAIGeminiAPIHook`` from this operator's settings.

    The hook receives the operator's ``gcp_conn_id``, ``impersonation_chain``
    and ``gemini_api_key``. The property is not cached, so each access
    constructs a new hook instance.
    """
    hook_kwargs = {
        "gcp_conn_id": self.gcp_conn_id,
        "impersonation_chain": self.impersonation_chain,
        "gemini_api_key": self.gemini_api_key,
    }
    return GenAIGeminiAPIHook(**hook_kwargs)

def execute(self, context: Context):
if self.deferrable:
self.defer(
trigger=GenAIGeminiCreateBatchJobTrigger(
project_id=self.project_id,
location=self.location,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
model=self.model,
input_source=self.input_source,
create_batch_job_config=self.create_batch_job_config,
gemini_api_key=self.gemini_api_key,
retrieve_result=self.retrieve_result,
polling_interval=self.polling_interval,
results_folder=self.results_folder,
),
method_name="execute_complete",
)

try:
job = self.hook.create_batch_job(
model=self.model,
Expand All @@ -560,6 +580,16 @@ def execute(self, context: Context):

return dict(job)

def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]:
    """
    Finish the task from the trigger's event after deferral.

    :param context: Task context; its ``"ti"`` entry is used to push XCom.
    :param event: Trigger payload. ``"status"`` is always read, ``"message"``
        is read when the status is ``"error"``, otherwise ``"job_name"`` is read.
    :return: The fetched batch job converted with ``dict()``.
    :raises AirflowException: When the event status is ``"error"``.
    """
    status = event["status"]
    if status == "error":
        error_message = event["message"]
        self.log.info("status: %s, msg: %s", status, error_message)
        raise AirflowException(error_message)

    # The event carries only the job name, so the job is fetched again here.
    batch_job = self.hook.get_batch_job(event["job_name"])

    # Results are pushed only when requested and the job reports no error.
    if not self.retrieve_result or batch_job.error is not None:
        return dict(batch_job)

    results = self._prepare_results_for_xcom(batch_job)
    context["ti"].xcom_push(key="job_results", value=results)
    return dict(batch_job)


class GenAIGeminiGetBatchJobOperator(GoogleCloudBaseOperator):
"""
Expand Down Expand Up @@ -841,6 +871,7 @@ class GenAIGeminiCreateEmbeddingsBatchJobOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Optional. Run operator in the deferrable mode.
"""

template_fields = (
Expand Down Expand Up @@ -868,6 +899,7 @@ def __init__(
wait_until_complete: bool = False,
retrieve_result: bool = False,
polling_interval: int = 30,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -884,9 +916,12 @@ def __init__(
self.retrieve_result = retrieve_result
self.polling_interval = polling_interval
self.results_folder = results_folder
self.deferrable = deferrable

if self.retrieve_result and not self.wait_until_complete:
raise AirflowException("Retrieving results is possible only if wait_until_complete set to True")
if self.retrieve_result and not (self.wait_until_complete or self.deferrable):
raise AirflowException(
"Retrieving results is possible only if wait_until_complete set to True or in deferrable mode"
)
if self.results_folder and not isinstance(self.input_source, str):
raise AirflowException("results_folder works only when input_source is file name")
if self.results_folder and not os.path.exists(os.path.abspath(self.results_folder)):
Expand Down Expand Up @@ -941,13 +976,33 @@ def _prepare_results_for_xcom(self, job):

return results

def execute(self, context: Context):
self.hook = GenAIGeminiAPIHook(
@property
def hook(self):
    """
    Build a ``GenAIGeminiAPIHook`` from this operator's settings.

    The hook receives the operator's ``gcp_conn_id``, ``impersonation_chain``
    and ``gemini_api_key``. The property is not cached, so each access
    constructs a new hook instance.
    """
    hook_kwargs = {
        "gcp_conn_id": self.gcp_conn_id,
        "impersonation_chain": self.impersonation_chain,
        "gemini_api_key": self.gemini_api_key,
    }
    return GenAIGeminiAPIHook(**hook_kwargs)

def execute(self, context: Context):
if self.deferrable:
self.defer(
trigger=GenAIGeminiCreateEmbeddingsBatchJobTrigger(
project_id=self.project_id,
location=self.location,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
model=self.model,
input_source=self.input_source,
create_embeddings_config=self.create_embeddings_config,
gemini_api_key=self.gemini_api_key,
retrieve_result=self.retrieve_result,
polling_interval=self.polling_interval,
results_folder=self.results_folder,
),
method_name="execute_complete",
)

try:
embeddings_job = self.hook.create_embeddings(
model=self.model,
Expand All @@ -968,10 +1023,19 @@ def execute(self, context: Context):

return embeddings_job.model_dump()

def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]:
    """
    Finish the task from the trigger's event after deferral.

    :param context: Task context; its ``"ti"`` entry is used to push XCom.
    :param event: Trigger payload. ``"status"`` is always read, ``"message"``
        is read when the status is ``"error"``, otherwise ``"job"`` is read.
    :return: The ``"job"`` entry of the event, unchanged.
    :raises AirflowException: When the event status is ``"error"``.
    """
    status = event["status"]
    if status == "error":
        error_message = event["message"]
        self.log.info("status: %s, msg: %s", status, error_message)
        raise AirflowException(error_message)

    # Results are pushed only when requested and the job payload has no error.
    if self.retrieve_result:
        job_payload = event["job"]
        if job_payload.get("error") is None:
            # The payload mapping is expanded into a BatchJob before results are prepared.
            results = self._prepare_results_for_xcom(BatchJob(**job_payload))
            context["ti"].xcom_push(key="job_results", value=results)

    return event["job"]


class GenAIGeminiUploadFileOperator(GoogleCloudBaseOperator):
"""
Get file uploaded to Gemini Files API.
Upload file to Gemini Files API.

The Files API lets you store up to 20GB of files per project, with each file not exceeding 2GB in size.
Supported types are audio files, images, videos, documents, and others. Files are stored for 48 hours.
Expand Down
Loading
Loading