diff --git a/providers/google/README.rst b/providers/google/README.rst index 6c252d575f974..842ead044e002 100644 --- a/providers/google/README.rst +++ b/providers/google/README.rst @@ -76,6 +76,7 @@ PIP package Version required ``google-api-python-client`` ``>=2.0.2`` ``google-auth`` ``>=2.29.0`` ``google-auth-httplib2`` ``>=0.0.1`` +``google-genai`` ``>=2.8.0`` ``google-cloud-aiplatform[evaluation]`` ``>=1.155.0`` ``ray[default]`` ``>=2.42.0; python_version < "3.13"`` ``ray[default]`` ``>=2.49.0; python_version >= "3.13" and python_version < "3.14"`` diff --git a/providers/google/docs/index.rst b/providers/google/docs/index.rst index ff860c759857c..90a49e2f6edda 100644 --- a/providers/google/docs/index.rst +++ b/providers/google/docs/index.rst @@ -129,6 +129,7 @@ PIP package Version required ``google-api-python-client`` ``>=2.0.2`` ``google-auth`` ``>=2.29.0`` ``google-auth-httplib2`` ``>=0.0.1`` +``google-genai`` ``>=2.8.0`` ``google-cloud-aiplatform[evaluation]`` ``>=1.155.0`` ``ray[default]`` ``>=2.42.0; python_version < "3.13"`` ``ray[default]`` ``>=2.49.0; python_version >= "3.13" and python_version < "3.14"`` diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index 246788c9e9302..f11b8787ef684 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -26,6 +26,71 @@ With Vertex AI, both AutoML training and custom training are available options. Whichever option you choose for training, you can save models, deploy models, and request predictions with Vertex AI. +Managing Agent Engines +^^^^^^^^^^^^^^^^^^^^^^ + +The operators below manage `Vertex AI Agent Engine +`__ resources. + +To create a Vertex AI Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.CreateAgentEngineOperator`. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_create_agent_engine_operator] + +To get an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.GetAgentEngineOperator`. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_get_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_get_agent_engine_operator] + +To run a query job on an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.RunQueryJobOperator`. +The operator uses the public ``run_query_job`` SDK method. The ``config`` parameter +can include ``query`` and ``output_gcs_uri``. The SDK writes query input and output +through Google Cloud Storage. By default, the operator waits for the query job to +complete and returns the serialized query job result. Set ``retrieve_result`` to +``True`` in ``check_config`` to return the query job result from Google Cloud Storage. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_run_query_job_operator] + :end-before: [END how_to_cloud_vertex_ai_run_query_job_operator] + +The same operation can be performed in the deferrable mode. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + :end-before: [END how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + +To update an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.UpdateAgentEngineOperator`. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_update_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_update_agent_engine_operator] + +To delete an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.DeleteAgentEngineOperator`. +By default, the operator waits until the delete operation completes. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_delete_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_delete_agent_engine_operator] + Creating Datasets ^^^^^^^^^^^^^^^^^ diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 44b6ca5efd6e9..b72ddd525382d 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -615,6 +615,7 @@ operators: - airflow.providers.google.cloud.operators.vertex_ai.generative_model - airflow.providers.google.cloud.operators.vertex_ai.feature_store - airflow.providers.google.cloud.operators.vertex_ai.ray + - airflow.providers.google.cloud.operators.vertex_ai.agent_engine - integration-name: Google Data Studio python-modules: - airflow.providers.google.cloud.operators.looker @@ -901,6 +902,7 @@ hooks: - airflow.providers.google.cloud.hooks.vertex_ai.generative_model - airflow.providers.google.cloud.hooks.vertex_ai.prediction_service - airflow.providers.google.cloud.hooks.vertex_ai.feature_store + - airflow.providers.google.cloud.hooks.vertex_ai.agent_engine - airflow.providers.google.cloud.hooks.vertex_ai.ray - integration-name: Google Data Studio python-modules: diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml index d42cc3d3a2802..d241551a9a240 100644 --- a/providers/google/pyproject.toml +++ b/providers/google/pyproject.toml @@ -80,6 +80,7 @@ dependencies = [ "google-api-python-client>=2.0.2", "google-auth>=2.29.0", "google-auth-httplib2>=0.0.1", + "google-genai>=2.8.0", # google-cloud-aiplatform doesn't install ray for python 3.12 (issue: https://github.com/googleapis/python-aiplatform/issues/5252). # Temporarily lock in ray 2.42.0 which is compatible with python 3.12 until linked issue is solved. # Remove the ray dependency as well as google-cloud-bigquery-storage once linked issue is fixed diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py new file mode 100644 index 0000000000000..3617147fd958d --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Vertex AI Agent Engine hook.""" + +from __future__ import annotations + +import time +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import google.auth.transport.requests +from aiohttp import ClientSession +from asgiref.sync import sync_to_async +from vertexai import Client + +from airflow.providers.google.common.hooks.base_google import ( + PROVIDE_PROJECT_ID, + GoogleBaseAsyncHook, + GoogleBaseHook, +) + +if TYPE_CHECKING: + from vertexai._genai import types + + +VERTEX_AI_AGENT_ENGINE_API_VERSION = "v1beta1" +VERTEX_AI_AGENT_ENGINE_OPERATION_URL = ( + "https://{location}-aiplatform.googleapis.com/{api_version}/{operation_name}" +) +DEFAULT_AGENT_ENGINE_OPERATION_REQUEST_TIMEOUT = 60.0 + + +def extract_operation_id(operation_name: str) -> str: + """Extract the operation ID from a fully qualified operation name.""" + return operation_name.rstrip("/").split("/")[-1] + + +def serialize_value(value: Any) -> Any: + """Recursively convert SDK model objects to JSON-serializable types.""" + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {key: serialize_value(item) for key, item in value.items()} + if isinstance(value, list): + return [serialize_value(item) for item in value] + if isinstance(value, tuple): + return tuple(serialize_value(item) for item in value) + return value + + +class AgentEngineHook(GoogleBaseHook): + """ + Hook for Google Cloud Vertex AI Agent Engine APIs. + + Wraps the ``agent_engines`` module of the Vertex AI SDK client: + https://docs.cloud.google.com/python/docs/reference/agentplatform/latest/vertexai._genai.agent_engines.AgentEngines + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def get_agent_engine_client(self, project_id: str, location: str): + """Return the Vertex AI Agent Engine client.""" + return Client( + project=project_id, + location=location, + credentials=self.get_credentials(), + ).agent_engines + + @staticmethod + def build_agent_engine_name(project_id: str, location: str, agent_engine_id: str) -> str: + """Build a fully qualified Agent Engine resource name.""" + return f"projects/{project_id}/locations/{location}/reasoningEngines/{agent_engine_id}" + + @staticmethod + def build_operation_name(project_id: str, location: str, operation_id: str) -> str: + """Build a fully qualified Agent Engine operation name.""" + return f"projects/{project_id}/locations/{location}/operations/{operation_id}" + + @GoogleBaseHook.fallback_to_default_project_id + def create_agent_engine( + self, + location: str, + agent: Any | None = None, + config: types.AgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.AgentEngine: + """ + Create an Agent Engine. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent: Optional. The agent object to deploy. + :param config: Optional. Configuration for the Agent Engine. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + return client.create(agent=agent, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def get_agent_engine( + self, + location: str, + agent_engine_id: str, + config: types.GetAgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.AgentEngine: + """ + Get an Agent Engine. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param config: Optional. Configuration for getting the Agent Engine. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) + return client.get(name=name, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def run_query_job( + self, + location: str, + agent_engine_id: str, + config: types.RunQueryJobAgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.RunQueryJobResult: + """ + Run a query job on an Agent Engine. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) + return client.run_query_job(name=name, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def check_query_agent_engine_job( + self, + location: str, + operation_id: str, + config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.CheckQueryJobResult: + """ + Check a query job on an Agent Engine. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param operation_id: Required. The query job operation ID. + :param config: Optional. Configuration for checking the query job. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + operation_name = self.build_operation_name(project_id, location, operation_id) + return client.check_query_job(name=operation_name, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def wait_for_query_agent_engine_job( + self, + location: str, + operation_id: str, + config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + poll_interval: float = 30, + timeout: float | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.CheckQueryJobResult: + """ + Wait until an Agent Engine query job completes. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param operation_id: Required. The query job operation ID. + :param config: Optional. Configuration for checking the query job. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + start_time = time.monotonic() + operation_name = self.build_operation_name(project_id, location, operation_id) + while True: + query_job = self.check_query_agent_engine_job( + project_id=project_id, + location=location, + operation_id=operation_id, + config=config, + ) + status = getattr(query_job, "status", None) + if status == "SUCCESS": + return query_job + if status == "FAILED": + raise RuntimeError(f"Agent Engine query job {operation_name} failed.") + if status not in (None, "RUNNING"): + raise RuntimeError( + f"Agent Engine query job {operation_name} completed with unexpected status {status}." + ) + if timeout is not None and time.monotonic() - start_time >= timeout: + raise TimeoutError(f"Timed out waiting for Agent Engine query job {operation_name}") + self.log.info("Waiting for Agent Engine query job %s to complete.", operation_name) + time.sleep(poll_interval) + + @GoogleBaseHook.fallback_to_default_project_id + def update_agent_engine( + self, + location: str, + agent_engine_id: str, + config: types.AgentEngineConfigOrDict, + agent: Any | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.AgentEngine: + """ + Update an Agent Engine. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param config: Required. Configuration for the Agent Engine update. + :param agent: Optional. The updated agent object to deploy. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) + return client.update(name=name, agent=agent, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_agent_engine( + self, + location: str, + agent_engine_id: str, + force: bool | None = None, + config: types.DeleteAgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.DeleteAgentEngineOperation: + """ + Delete an Agent Engine. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param force: Optional. Whether to forcefully delete child resources. Defaults to ``False`` + when not specified. + :param config: Optional. Additional deletion configuration. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) + return client.delete(name=name, force=force, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def get_agent_engine_operation( + self, + location: str, + operation_id: str, + request_timeout: float | None = DEFAULT_AGENT_ENGINE_OPERATION_REQUEST_TIMEOUT, + project_id: str = PROVIDE_PROJECT_ID, + ) -> dict[str, Any]: + """ + Return a Vertex AI Agent Engine long-running operation. + + :param location: The ID of the Google Cloud location that the service belongs to. + :param operation_id: The Agent Engine operation ID. + :param request_timeout: Optional timeout, in seconds, for the operation request. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + operation_name = self.build_operation_name(project_id, location, operation_id) + url = VERTEX_AI_AGENT_ENGINE_OPERATION_URL.format( + location=location, + api_version=VERTEX_AI_AGENT_ENGINE_API_VERSION, + operation_name=operation_name, + ) + session = google.auth.transport.requests.AuthorizedSession(self.get_credentials()) + response = session.get(url, timeout=request_timeout) + response.raise_for_status() + return response.json() + + @GoogleBaseHook.fallback_to_default_project_id + def wait_for_agent_engine_operation( + self, + location: str, + operation_id: str, + poll_interval: float = 30, + timeout: float | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> None: + """ + Wait until an Agent Engine operation completes. + + :param location: The ID of the Google Cloud location that the service belongs to. + :param operation_id: The Agent Engine operation ID. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. + """ + start_time = time.monotonic() + operation_name = self.build_operation_name(project_id, location, operation_id) + while True: + operation = self.get_agent_engine_operation( + project_id=project_id, + location=location, + operation_id=operation_id, + ) + if operation.get("done"): + if operation.get("error"): + raise RuntimeError( + f"Agent Engine operation {operation_name} failed: {operation['error']}" + ) + return + if timeout is not None and time.monotonic() - start_time >= timeout: + raise TimeoutError(f"Timed out waiting for Agent Engine operation {operation_name}") + self.log.info("Waiting for Agent Engine operation %s to complete.", operation_name) + time.sleep(poll_interval) + + +class AgentEngineAsyncHook(GoogleBaseAsyncHook): + """Async hook for Google Cloud Vertex AI Agent Engine APIs.""" + + sync_hook_class = AgentEngineHook + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ): + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + async def get_agent_engine_operation( + self, + location: str, + operation_id: str, + request_timeout: float | None = DEFAULT_AGENT_ENGINE_OPERATION_REQUEST_TIMEOUT, + project_id: str = PROVIDE_PROJECT_ID, + ) -> dict[str, Any]: + """Return a Vertex AI Agent Engine long-running operation.""" + sync_hook = await self.get_sync_hook() + project_id = project_id or sync_hook.project_id + operation_name = self.sync_hook_class.build_operation_name(project_id, location, operation_id) + url = VERTEX_AI_AGENT_ENGINE_OPERATION_URL.format( + location=location, + api_version=VERTEX_AI_AGENT_ENGINE_API_VERSION, + operation_name=operation_name, + ) + credentials = sync_hook.get_credentials() + + if not credentials.valid: + credentials.refresh(google.auth.transport.requests.Request()) + + async with ClientSession() as session: + async with session.get( + url, + headers={"Authorization": f"Bearer {credentials.token}"}, + timeout=request_timeout, + ) as response: + response.raise_for_status() + return await response.json() + + async def check_query_agent_engine_job( + self, + location: str, + operation_id: str, + config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.CheckQueryJobResult: + """Check a query job on an Agent Engine.""" + sync_hook = await self.get_sync_hook() + return await sync_to_async(sync_hook.check_query_agent_engine_job)( + project_id=project_id, + location=location, + operation_id=operation_id, + config=config, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py new file mode 100644 index 0000000000000..f6e2230ddc711 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -0,0 +1,441 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Vertex AI Agent Engine operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import conf +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import ( + AgentEngineHook, + extract_operation_id, + serialize_value, +) +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineQueryJobTrigger + +if TYPE_CHECKING: + from vertexai._genai import types + + from airflow.providers.common.compat.sdk import Context + + +def _serialize_agent_engine(agent_engine: types.AgentEngine) -> dict[str, Any]: + api_resource = getattr(agent_engine, "api_resource", None) + if api_resource is not None: + return serialize_value(api_resource) + return serialize_value(agent_engine) + + +class CreateAgentEngineOperator(GoogleCloudBaseOperator): + """ + Create a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent: Optional. The agent object to deploy. + :param config: Optional. Configuration for the Agent Engine. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + """ + + template_fields = ( + "project_id", + "location", + "agent", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + agent: Any | None = None, + config: types.AgentEngineConfigOrDict | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.agent = agent + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Creating Agent Engine.") + agent_engine = self.hook.create_agent_engine( + project_id=self.project_id, + location=self.location, + agent=self.agent, + config=self.config, + ) + result = _serialize_agent_engine(agent_engine) + self.log.info("Agent Engine was created.") + return result + + +class GetAgentEngineOperator(GoogleCloudBaseOperator): + """ + Get a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param config: Optional. Configuration for getting the Agent Engine. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + """ + + template_fields = ( + "project_id", + "location", + "agent_engine_id", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + agent_engine_id: str, + config: types.GetAgentEngineConfigOrDict | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.agent_engine_id = agent_engine_id + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Getting Agent Engine %s.", self.agent_engine_id) + agent_engine = self.hook.get_agent_engine( + project_id=self.project_id, + location=self.location, + agent_engine_id=self.agent_engine_id, + config=self.config, + ) + result = _serialize_agent_engine(agent_engine) + self.log.info("Agent Engine %s was retrieved.", self.agent_engine_id) + return result + + +class RunQueryJobOperator(GoogleCloudBaseOperator): + """ + Run a query job on a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). + :param check_config: Optional. Configuration for checking the query job. + :param wait_for_completion: Whether to wait until the query job completes. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + :param deferrable: Run operator in the deferrable mode. + """ + + template_fields = ( + "project_id", + "location", + "agent_engine_id", + "config", + "check_config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + agent_engine_id: str, + config: types.RunQueryJobAgentEngineConfigOrDict | None = None, + check_config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + wait_for_completion: bool = True, + poll_interval: float = 30, + timeout: float | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.agent_engine_id = agent_engine_id + self.config = config + self.check_config = check_config + self.wait_for_completion = wait_for_completion + self.poll_interval = poll_interval + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.deferrable = deferrable + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Running query job on Agent Engine %s.", self.agent_engine_id) + query_job = self.hook.run_query_job( + project_id=self.project_id, + location=self.location, + agent_engine_id=self.agent_engine_id, + config=self.config, + ) + result = serialize_value(query_job) + self.log.info("Query job was started on Agent Engine %s.", self.agent_engine_id) + if not self.wait_for_completion: + return result + + operation_name = getattr(query_job, "job_name", None) + if not operation_name: + raise RuntimeError("Agent Engine query job did not include an operation name.") + operation_id = extract_operation_id(operation_name) + + if self.deferrable: + return self.defer( + trigger=AgentEngineQueryJobTrigger( + project_id=self.project_id, + location=self.location, + operation_id=operation_id, + config=self.check_config, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + poll_interval=self.poll_interval, + timeout=self.timeout, + ), + method_name="execute_complete", + ) + + completed_job = self.hook.wait_for_query_agent_engine_job( + project_id=self.project_id, + location=self.location, + operation_id=operation_id, + config=self.check_config, + poll_interval=self.poll_interval, + timeout=self.timeout, + ) + self.log.info("Agent Engine query job %s completed.", operation_name) + return serialize_value(completed_job) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + if event is None: + raise RuntimeError("No event received in trigger callback") + if event["status"] == "success": + self.log.info("Agent Engine query job completed.") + return event["query_job"] + if event["status"] == "timeout": + raise TimeoutError(event["message"]) + raise RuntimeError(event["message"]) + + +class UpdateAgentEngineOperator(GoogleCloudBaseOperator): + """ + Update a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param agent: Optional. The updated agent object to deploy. + :param config: Required. Configuration for the Agent Engine update. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + """ + + template_fields = ( + "project_id", + "location", + "agent_engine_id", + "agent", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + agent_engine_id: str, + config: types.AgentEngineConfigOrDict, + agent: Any | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.agent_engine_id = agent_engine_id + self.agent = agent + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Updating Agent Engine %s.", self.agent_engine_id) + agent_engine = self.hook.update_agent_engine( + project_id=self.project_id, + location=self.location, + agent_engine_id=self.agent_engine_id, + agent=self.agent, + config=self.config, + ) + result = _serialize_agent_engine(agent_engine) + self.log.info("Agent Engine %s was updated.", self.agent_engine_id) + return result + + +class DeleteAgentEngineOperator(GoogleCloudBaseOperator): + """ + Delete a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine_id: Required. The Agent Engine ID. + :param force: Optional. Whether to delete child resources. + :param config: Optional. Additional deletion configuration. + :param wait_for_completion: Whether to wait until the delete operation completes. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + """ + + template_fields = ( + "project_id", + "location", + "agent_engine_id", + "force", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + agent_engine_id: str, + force: bool | None = None, + config: types.DeleteAgentEngineConfigOrDict | None = None, + wait_for_completion: bool = True, + poll_interval: float = 30, + timeout: float | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.agent_engine_id = agent_engine_id + self.force = force + self.config = config + self.wait_for_completion = wait_for_completion + self.poll_interval = poll_interval + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Deleting Agent Engine %s.", self.agent_engine_id) + operation = self.hook.delete_agent_engine( + project_id=self.project_id, + location=self.location, + agent_engine_id=self.agent_engine_id, + force=self.force, + config=self.config, + ) + result = serialize_value(operation) + if not self.wait_for_completion: + return result + + operation_name = getattr(operation, "name", None) + if not operation_name: + raise RuntimeError("Delete Agent Engine operation did not include an operation name.") + operation_id = extract_operation_id(operation_name) + + if getattr(operation, "done", False): + self.log.info("Agent Engine %s was deleted.", self.agent_engine_id) + return result + + self.hook.wait_for_agent_engine_operation( + project_id=self.project_id, + location=self.location, + operation_id=operation_id, + poll_interval=self.poll_interval, + timeout=self.timeout, + ) + self.log.info("Agent Engine %s was deleted.", self.agent_engine_id) + return result diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index d39e96734fee5..6e7ecff02fc60 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import asyncio +import time from collections.abc import AsyncIterator, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any @@ -29,6 +31,7 @@ ) from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook, serialize_value from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( @@ -39,6 +42,7 @@ if TYPE_CHECKING: from proto import Message + from vertexai._genai import types as vertexai_types class BaseVertexAIJobTrigger(BaseTrigger): @@ -126,6 +130,118 @@ def _serialize_job(self, job: Any) -> Any: return self.job_serializer_class.to_dict(job) +class AgentEngineQueryJobTrigger(BaseTrigger): + """Trigger that waits until a Vertex AI Agent Engine query job completes.""" + + def __init__( + self, + project_id: str, + location: str, + operation_id: str, + config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + poll_interval: float = 30, + timeout: float | None = None, + ): + super().__init__() + self.project_id = project_id + self.location = location + self.operation_id = operation_id + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.poll_interval = poll_interval + self.timeout = timeout + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger", + { + "project_id": self.project_id, + "location": self.location, + "operation_id": self.operation_id, + "config": serialize_value(self.config), + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "poll_interval": self.poll_interval, + "timeout": self.timeout, + }, + ) + + @cached_property + def async_hook(self) -> AgentEngineAsyncHook: + return AgentEngineAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + start_time = time.monotonic() + try: + while True: + query_job = await self.async_hook.check_query_agent_engine_job( + project_id=self.project_id, + location=self.location, + operation_id=self.operation_id, + config=self.config, + ) + status = getattr(query_job, "status", None) + serialized_query_job = serialize_value(query_job) + if status == "SUCCESS": + yield TriggerEvent( + { + "status": "success", + "message": "Agent Engine query job completed", + "query_job": serialized_query_job, + } + ) + return + if status == "FAILED": + yield TriggerEvent( + { + "status": "error", + "message": f"Agent Engine query job {self.operation_id} failed.", + "query_job": serialized_query_job, + } + ) + return + if status not in (None, "RUNNING"): + yield TriggerEvent( + { + "status": "error", + "message": ( + f"Agent Engine query job {self.operation_id} completed with " + f"unexpected status {status}." + ), + "query_job": serialized_query_job, + } + ) + return + + if self.timeout is not None and time.monotonic() - start_time >= self.timeout: + yield TriggerEvent( + { + "status": "timeout", + "message": f"Timed out waiting for Agent Engine query job {self.operation_id}", + "query_job": serialized_query_job, + } + ) + return + + self.log.info("Waiting for Agent Engine query job %s to complete.", self.operation_id) + await asyncio.sleep(self.poll_interval) + except Exception as err: + self.log.exception("Exception occurred while waiting for Agent Engine query job.") + yield TriggerEvent( + { + "status": "error", + "message": f"Failed while polling Agent Engine query job: {err}", + "query_job": {"operation_id": self.operation_id}, + } + ) + + class CreateHyperparameterTuningJobTrigger(BaseVertexAIJobTrigger): """CreateHyperparameterTuningJobTrigger run on the trigger worker to perform create operation.""" diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index 61571bc6831f3..92966a4b3679b 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -672,6 +672,7 @@ def get_provider_info(): "airflow.providers.google.cloud.operators.vertex_ai.generative_model", "airflow.providers.google.cloud.operators.vertex_ai.feature_store", "airflow.providers.google.cloud.operators.vertex_ai.ray", + "airflow.providers.google.cloud.operators.vertex_ai.agent_engine", ], }, { @@ -1043,6 +1044,7 @@ def get_provider_info(): "airflow.providers.google.cloud.hooks.vertex_ai.generative_model", "airflow.providers.google.cloud.hooks.vertex_ai.prediction_service", "airflow.providers.google.cloud.hooks.vertex_ai.feature_store", + "airflow.providers.google.cloud.hooks.vertex_ai.agent_engine", "airflow.providers.google.cloud.hooks.vertex_ai.ray", ], }, diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py new file mode 100644 index 0000000000000..19756f1e50d15 --- /dev/null +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -0,0 +1,205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow Dag for Google Vertex AI Agent Engine operations. +""" + +from __future__ import annotations + +import json +import os +from datetime import datetime +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vertexai._genai import types as vertexai_types + +from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( + CreateAgentEngineOperator, + DeleteAgentEngineOperator, + GetAgentEngineOperator, + RunQueryJobOperator, + UpdateAgentEngineOperator, +) + +try: + from airflow.sdk import DAG, TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.models.dag import DAG # type: ignore[attr-defined,no-redef,assignment] + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +DAG_ID = "vertex_ai_agent_engine_operations" +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") + + +def _get_env(name: str, default: str = "") -> str: + return os.environ.get(name) or os.environ.get(f"AIRFLOW_VAR_{name}", default) + + +def _get_json_env(name: str, default: dict[str, str]) -> dict[str, str]: + value = os.environ.get(name) + return json.loads(value) if value else default + + +def _get_container_env_vars() -> dict[str, str]: + return _get_json_env( + "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_ENV_VARS", + {}, + ) + + +LOCATION = _get_env("GCP_REGION", "us-central1") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or _get_env("GCP_PROJECT_ID", "default") +CONTAINER_URI = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_URI") or _get_env( + "GCP_AGENT_ENGINE_CONTAINER_URI", + "us-central1-docker.pkg.dev/example-project/example-repository/airflow-hello-agent:latest", +) +CONTAINER_ENV_VARS = _get_container_env_vars() +QUERY_STR = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY") or json.dumps( + { + "input": "hello from Airflow", + } +) +QUERY_OUTPUT_GCS_URI = os.environ.get( + "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY_OUTPUT_GCS_URI", + f"gs://{PROJECT_ID}-agent-engine-query-output/query-output/", +) + +AGENT_ENGINE_ID = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'].split('/')[-1] }}" +DISPLAY_NAME = f"airflow-agent-engine-{ENV_ID}" + +QUERY_CONFIG: vertexai_types.RunQueryJobAgentEngineConfigDict = { + "query": QUERY_STR, + "output_gcs_uri": QUERY_OUTPUT_GCS_URI, +} + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "vertex_ai", "agent_engine"], +) as dag: + # [START how_to_cloud_vertex_ai_create_agent_engine_operator] + create_agent_engine = CreateAgentEngineOperator( + task_id="create_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + config={ + "display_name": DISPLAY_NAME, + "description": "Airflow system test Agent Engine", + "agent_framework": "custom", + "min_instances": 0, + "max_instances": 1, + "resource_limits": {"cpu": "1", "memory": "1Gi"}, + "container_spec": {"image_uri": CONTAINER_URI}, + **({"env_vars": CONTAINER_ENV_VARS} if CONTAINER_ENV_VARS else {}), + "class_methods": [ + { + "name": "query", + "api_mode": "", + }, + ], + }, + ) + # [END how_to_cloud_vertex_ai_create_agent_engine_operator] + + # [START how_to_cloud_vertex_ai_get_agent_engine_operator] + get_agent_engine = GetAgentEngineOperator( + task_id="get_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + # [END how_to_cloud_vertex_ai_get_agent_engine_operator] + + # [START how_to_cloud_vertex_ai_run_query_job_operator] + run_query_job = RunQueryJobOperator( + task_id="run_query_job", + project_id=PROJECT_ID, + location=LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + check_config={"retrieve_result": True}, + poll_interval=10, + timeout=900, + ) + # [END how_to_cloud_vertex_ai_run_query_job_operator] + + # [START how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + run_query_job_deferrable = RunQueryJobOperator( + task_id="run_query_job_deferrable", + project_id=PROJECT_ID, + location=LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + check_config={"retrieve_result": True}, + poll_interval=10, + timeout=900, + deferrable=True, + ) + # [END how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + + # [START how_to_cloud_vertex_ai_update_agent_engine_operator] + update_agent_engine = UpdateAgentEngineOperator( + task_id="update_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config={ + "display_name": f"{DISPLAY_NAME}-updated", + "description": "Updated Airflow system test Agent Engine", + }, + ) + # [END how_to_cloud_vertex_ai_update_agent_engine_operator] + + # [START how_to_cloud_vertex_ai_delete_agent_engine_operator] + delete_agent_engine = DeleteAgentEngineOperator( + task_id="delete_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + force=True, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END how_to_cloud_vertex_ai_delete_agent_engine_operator] + + ( + create_agent_engine + >> get_agent_engine + >> run_query_job + >> run_query_job_deferrable + >> update_agent_engine + >> delete_agent_engine + ) + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the Dag + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example Dag with pytest (see: contributing-docs/testing/system_tests.rst) +test_run = get_test_run(dag) diff --git a/providers/google/tests/system/google/cloud/vertex_ai/resources/agent_engine/__init__.py b/providers/google/tests/system/google/cloud/vertex_ai/resources/agent_engine/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/google/tests/system/google/cloud/vertex_ai/resources/agent_engine/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py new file mode 100644 index 0000000000000..ed4f0c7220188 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -0,0 +1,426 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook, AgentEngineHook + +from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +AGENT_ENGINE_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.agent_engine.{}" + +TEST_GCP_CONN_ID = "test-gcp-conn-id" +GCP_PROJECT = "test-project" +GCP_LOCATION = "us-central1" +AGENT_ENGINE_ID = "123" +AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" +OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" +QUERY_OPERATION_NAME = "projects/test-project/locations/us-central1/operations/query-123" +OPERATION_ID = "delete-123" +QUERY_OPERATION_ID = "query-123" +CONFIG = {"display_name": "test-agent-engine"} +QUERY_CONFIG = {"query": "hello", "output_gcs_uri": "gs://test-bucket/query-output/"} +CHECK_QUERY_CONFIG = {"retrieve_result": True} + + +class TestAgentEngineHookWithDefaultProjectId: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = AgentEngineHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(AGENT_ENGINE_STRING.format("Client"), autospec=True) + def test_get_agent_engine_client(self, mock_client): + self.hook.get_credentials = mock.Mock(return_value=mock.sentinel.credentials, spec=()) + + result = self.hook.get_agent_engine_client(project_id=GCP_PROJECT, location=GCP_LOCATION) + + mock_client.assert_called_once_with( + project=GCP_PROJECT, + location=GCP_LOCATION, + credentials=mock.sentinel.credentials, + ) + assert result == mock_client.return_value.agent_engines + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_create_agent_engine(self, mock_get_client): + result = self.hook.create_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + config=CONFIG, + ) + + mock_get_client.assert_called_once_with(self.hook, project_id=GCP_PROJECT, location=GCP_LOCATION) + mock_get_client.return_value.create.assert_called_once_with( + agent=None, + config=CONFIG, + ) + assert result == mock_get_client.return_value.create.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_get_agent_engine(self, mock_get_client): + result = self.hook.get_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + + mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME, config=None) + assert result == mock_get_client.return_value.get.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_get_agent_engine_with_config(self, mock_get_client): + result = self.hook.get_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=CONFIG, + ) + + mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME, config=CONFIG) + assert result == mock_get_client.return_value.get.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_run_query_job(self, mock_get_client): + result = self.hook.run_query_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + ) + + mock_get_client.return_value.run_query_job.assert_called_once_with( + name=AGENT_ENGINE_NAME, + config=QUERY_CONFIG, + ) + assert result == mock_get_client.return_value.run_query_job.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_check_query_agent_engine_job(self, mock_get_client): + result = self.hook.check_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + + mock_get_client.return_value.check_query_job.assert_called_once_with( + name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + assert result == mock_get_client.return_value.check_query_job.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_returns_when_successful(self, mock_check_query_job): + mock_check_query_job.return_value.status = "SUCCESS" + + result = self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + + mock_check_query_job.assert_called_once_with( + self.hook, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + assert result == mock_check_query_job.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_polls_until_success(self, mock_check_query_job, mock_sleep): + running_job = mock.Mock(status="RUNNING") + success_job = mock.Mock(status="SUCCESS") + mock_check_query_job.side_effect = [running_job, running_job, success_job] + + result = self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + poll_interval=10, + ) + + assert result is success_job + assert mock_check_query_job.call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(10) + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_raises_on_failed_status(self, mock_check_query_job): + mock_check_query_job.return_value.status = "FAILED" + + with pytest.raises(RuntimeError, match="Agent Engine query job .* failed"): + self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_raises_on_unexpected_status( + self, mock_check_query_job, mock_sleep + ): + mock_check_query_job.return_value.status = "CANCELLED" + + with pytest.raises( + RuntimeError, + match=f"Agent Engine query job {QUERY_OPERATION_NAME} completed with unexpected status CANCELLED.", + ): + self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + + mock_sleep.assert_not_called() + + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_times_out( + self, mock_check_query_job, mock_monotonic, mock_sleep + ): + mock_check_query_job.return_value.status = "RUNNING" + mock_monotonic.side_effect = [1, 3] + + with pytest.raises(TimeoutError, match="Timed out waiting for Agent Engine query job"): + self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + timeout=1, + ) + + mock_sleep.assert_not_called() + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_update_agent_engine(self, mock_get_client): + result = self.hook.update_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=CONFIG, + ) + + mock_get_client.return_value.update.assert_called_once_with( + name=AGENT_ENGINE_NAME, + agent=None, + config=CONFIG, + ) + assert result == mock_get_client.return_value.update.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_delete_agent_engine(self, mock_get_client): + result = self.hook.delete_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + force=True, + config=CONFIG, + ) + + mock_get_client.return_value.delete.assert_called_once_with( + name=AGENT_ENGINE_NAME, + force=True, + config=CONFIG, + ) + assert result == mock_get_client.return_value.delete.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + def test_get_agent_engine_operation(self, mock_session): + self.hook.get_credentials = mock.Mock(return_value=mock.sentinel.credentials, spec=()) + mock_session.return_value.get.return_value.json.return_value = {"name": OPERATION_NAME, "done": True} + + result = self.hook.get_agent_engine_operation( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + ) + + mock_session.assert_called_once_with(mock.sentinel.credentials) + mock_session.return_value.get.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}", + timeout=60.0, + ) + mock_session.return_value.get.return_value.raise_for_status.assert_called_once_with() + assert result == {"name": OPERATION_NAME, "done": True} + + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + def test_get_agent_engine_operation_with_request_timeout(self, mock_session): + self.hook.get_credentials = mock.Mock(return_value=mock.sentinel.credentials, spec=()) + mock_session.return_value.get.return_value.json.return_value = {"name": OPERATION_NAME, "done": True} + + self.hook.get_agent_engine_operation( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + request_timeout=10, + ) + + mock_session.return_value.get.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}", + timeout=10, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) + def test_wait_for_agent_engine_operation_returns_when_done(self, mock_get_operation): + mock_get_operation.return_value = {"name": OPERATION_NAME, "done": True} + + self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + ) + + mock_get_operation.assert_called_once_with( + self.hook, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) + def test_wait_for_agent_engine_operation_polls_until_done(self, mock_get_operation, mock_sleep): + running_operation = {"name": OPERATION_NAME, "done": False} + done_operation = {"name": OPERATION_NAME, "done": True} + mock_get_operation.side_effect = [running_operation, running_operation, done_operation] + + self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + poll_interval=10, + ) + + assert mock_get_operation.call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(10) + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) + def test_wait_for_agent_engine_operation_raises_on_error(self, mock_get_operation): + mock_get_operation.return_value = {"name": OPERATION_NAME, "done": True, "error": {"message": "boom"}} + + with pytest.raises(RuntimeError, match="Agent Engine operation .* failed"): + self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) + def test_wait_for_agent_engine_operation_times_out(self, mock_get_operation, mock_monotonic, mock_sleep): + mock_get_operation.return_value = {"name": OPERATION_NAME, "done": False} + mock_monotonic.side_effect = [1, 3] + + with pytest.raises(TimeoutError, match="Timed out waiting for Agent Engine operation"): + self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + timeout=1, + ) + + mock_sleep.assert_not_called() + + +class TestAgentEngineAsyncHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseAsyncHook.__init__"), + return_value=None, + ): + self.hook = AgentEngineAsyncHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(AGENT_ENGINE_STRING.format("ClientSession"), autospec=True) + @pytest.mark.parametrize( + ("valid_credentials", "expected_refresh_calls"), + [ + (False, 1), + (True, 0), + ], + ) + @pytest.mark.asyncio + async def test_get_agent_engine_operation( + self, mock_client_session, valid_credentials, expected_refresh_calls + ): + sync_hook = mock.Mock(spec=AgentEngineHook) + sync_hook.project_id = GCP_PROJECT + credentials = mock.Mock(valid=valid_credentials, token="test-token") + sync_hook.get_credentials.return_value = credentials + self.hook.get_sync_hook = mock.AsyncMock(return_value=sync_hook) + session = mock.MagicMock() + mock_client_session.return_value.__aenter__.return_value = session + response = mock.Mock() + response.json = mock.AsyncMock(return_value={"name": OPERATION_NAME, "done": True}) + request = mock.MagicMock() + request.__aenter__.return_value = response + session.get.return_value = request + + result = await self.hook.get_agent_engine_operation( + location=GCP_LOCATION, + operation_id=OPERATION_ID, + request_timeout=10, + ) + + assert credentials.refresh.call_count == expected_refresh_calls + session.get.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}", + headers={"Authorization": "Bearer test-token"}, + timeout=10, + ) + response.raise_for_status.assert_called_once_with() + sync_hook.get_agent_engine_operation.assert_not_called() + assert result == {"name": OPERATION_NAME, "done": True} + + @pytest.mark.asyncio + async def test_check_query_agent_engine_job_calls_sync_hook(self): + sync_hook = mock.Mock(spec=AgentEngineHook) + sync_hook.check_query_agent_engine_job.return_value = mock.sentinel.query_job + self.hook.get_sync_hook = mock.AsyncMock(return_value=sync_hook) + + result = await self.hook.check_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + + sync_hook.check_query_agent_engine_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + assert result == mock.sentinel.query_job diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py new file mode 100644 index 0000000000000..4f67bba9fe5f6 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -0,0 +1,439 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.common.compat.sdk import TaskDeferred +from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( + CreateAgentEngineOperator, + DeleteAgentEngineOperator, + GetAgentEngineOperator, + RunQueryJobOperator, + UpdateAgentEngineOperator, +) + +AGENT_ENGINE_PATH = "airflow.providers.google.cloud.operators.vertex_ai.agent_engine.{}" + +TASK_ID = "test_task_id" +GCP_PROJECT = "test-project" +GCP_LOCATION = "us-central1" +GCP_CONN_ID = "test-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +AGENT_ENGINE_ID = "123" +AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" +CONFIG = {"display_name": "test-agent-engine"} +QUERY_CONFIG = {"query": "hello", "output_gcs_uri": "gs://test-bucket/query-output/"} +CHECK_QUERY_CONFIG = {"retrieve_result": True} +OPERATION = {"name": "operations/delete-123", "done": False} +QUERY_OPERATION_NAME = "operations/query-123" +OPERATION_ID = "delete-123" +QUERY_OPERATION_ID = "query-123" + + +class FakeModel: + def __init__(self, payload): + self.payload = payload + for key, value in payload.items(): + setattr(self, key, value) + + def model_dump(self, mode="json"): + return self.payload + + +class FakeAgentEngine: + def __init__(self, payload): + self.api_resource = FakeModel(payload) + + +@pytest.fixture +def context(): + return {"ti": mock.Mock(spec_set=["xcom_push"])} + + +def assert_hook_created(mock_hook): + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + +class TestCreateAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + mock_hook.return_value.create_agent_engine.return_value = FakeAgentEngine( + {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} + ) + op = CreateAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + assert_hook_created(mock_hook) + mock_hook.return_value.create_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent=None, + config=CONFIG, + ) + assert result == {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} + + +class TestGetAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + mock_hook.return_value.get_agent_engine.return_value = FakeAgentEngine({"name": AGENT_ENGINE_NAME}) + op = GetAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.get_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=None, + ) + assert result == {"name": AGENT_ENGINE_NAME} + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_with_config(self, mock_hook, context): + mock_hook.return_value.get_agent_engine.return_value = FakeAgentEngine({"name": AGENT_ENGINE_NAME}) + op = GetAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + op.execute(context=context) + + mock_hook.return_value.get_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=CONFIG, + ) + + +class TestRunQueryJobOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + run_result_payload = { + "job_name": "operations/query-123", + "input_gcs_uri": "gs://test-bucket/query-output/input.json", + "output_gcs_uri": "gs://test-bucket/query-output/output.json", + } + query_result_payload = { + "operation_name": QUERY_OPERATION_NAME, + "output_gcs_uri": "gs://test-bucket/query-output/output.json", + "status": "SUCCESS", + "result": "done", + } + mock_hook.return_value.run_query_job.return_value = FakeModel(run_result_payload) + mock_hook.return_value.wait_for_query_agent_engine_job.return_value = FakeModel(query_result_payload) + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + check_config=CHECK_QUERY_CONFIG, + poll_interval=1, + timeout=60, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.run_query_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + ) + mock_hook.return_value.wait_for_query_agent_engine_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + poll_interval=1, + timeout=60, + ) + assert result == query_result_payload + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_without_wait(self, mock_hook, context): + result_payload = { + "job_name": "operations/query-123", + "input_gcs_uri": "gs://test-bucket/query-output/input.json", + "output_gcs_uri": "gs://test-bucket/query-output/output.json", + } + mock_hook.return_value.run_query_job.return_value = FakeModel(result_payload) + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + config=CHECK_QUERY_CONFIG, + agent_engine_id=AGENT_ENGINE_ID, + wait_for_completion=False, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.wait_for_query_agent_engine_job.assert_not_called() + assert result == result_payload + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineQueryJobTrigger"), autospec=True) + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_deferrable(self, mock_hook, mock_trigger, context): + mock_hook.return_value.run_query_job.return_value = FakeModel({"job_name": QUERY_OPERATION_NAME}) + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=CHECK_QUERY_CONFIG, + check_config=CHECK_QUERY_CONFIG, + poll_interval=1, + timeout=60, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + deferrable=True, + ) + + with pytest.raises(TaskDeferred): + op.execute(context=context) + + mock_hook.return_value.wait_for_query_agent_engine_job.assert_not_called() + mock_trigger.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + poll_interval=1, + timeout=60, + ) + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_raises_when_query_job_has_no_name(self, mock_hook, context): + mock_hook.return_value.run_query_job.return_value = FakeModel({}) + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + + with pytest.raises(RuntimeError, match="Agent Engine query job did not include an operation name."): + op.execute(context=context) + + mock_hook.return_value.wait_for_query_agent_engine_job.assert_not_called() + + def test_execute_complete_success(self, context): + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + query_job = {"operation_name": QUERY_OPERATION_NAME, "status": "SUCCESS"} + + result = op.execute_complete( + context=context, + event={"status": "success", "message": "done", "query_job": query_job}, + ) + + assert result == query_job + + def test_execute_complete_error(self, context): + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + + with pytest.raises(RuntimeError, match="boom"): + op.execute_complete(context=context, event={"status": "error", "message": "boom"}) + + def test_execute_complete_timeout(self, context): + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + + with pytest.raises(TimeoutError, match="timed out"): + op.execute_complete(context=context, event={"status": "timeout", "message": "timed out"}) + + def test_execute_complete_without_event(self, context): + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + + with pytest.raises(RuntimeError, match="No event received in trigger callback"): + op.execute_complete(context=context) + + +class TestUpdateAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + mock_hook.return_value.update_agent_engine.return_value = FakeAgentEngine( + {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + ) + op = UpdateAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.update_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + agent=None, + config=CONFIG, + ) + assert result == {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + + +class TestDeleteAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_without_wait(self, mock_hook, context): + mock_hook.return_value.delete_agent_engine.return_value = FakeModel(OPERATION) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + force=True, + config=CONFIG, + wait_for_completion=False, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.delete_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + force=True, + config=CONFIG, + ) + mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() + assert result == OPERATION + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_waits_until_deleted(self, mock_hook, context): + mock_hook.return_value.delete_agent_engine.return_value = FakeModel(OPERATION) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + wait_for_completion=True, + poll_interval=1, + timeout=60, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.wait_for_agent_engine_operation.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + poll_interval=1, + timeout=60, + ) + assert result == OPERATION + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_does_not_wait_when_delete_operation_is_done(self, mock_hook, context): + operation = {"name": "operations/delete-123", "done": True} + mock_hook.return_value.delete_agent_engine.return_value = FakeModel(operation) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + wait_for_completion=True, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() + assert result == operation + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_raises_when_delete_operation_has_no_name(self, mock_hook, context): + mock_hook.return_value.delete_agent_engine.return_value = FakeModel({"done": False}) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + wait_for_completion=True, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + with pytest.raises( + RuntimeError, match=r"Delete Agent Engine operation did not include an operation name\." + ): + op.execute(context=context) + + mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py new file mode 100644 index 0000000000000..dfb2b418be5d9 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -0,0 +1,207 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineQueryJobTrigger +from airflow.triggers.base import TriggerEvent + +GCP_PROJECT = "test-project" +GCP_LOCATION = "us-central1" +GCP_CONN_ID = "test-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +AGENT_ENGINE_ID = "123" +QUERY_OPERATION_NAME = "projects/test-project/locations/us-central1/operations/query-123" +QUERY_OPERATION_ID = "query-123" +CHECK_QUERY_CONFIG = {"retrieve_result": True} + + +class FakeModel: + def __init__(self, payload): + self.payload = payload + for key, value in payload.items(): + setattr(self, key, value) + + def model_dump(self, mode="json"): + return self.payload + + +@pytest.fixture +def query_job_trigger(): + return AgentEngineQueryJobTrigger( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + poll_interval=1, + timeout=60, + ) + + +class TestAgentEngineQueryJobTrigger: + def test_serialize(self, query_job_trigger): + assert query_job_trigger.serialize() == ( + "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger", + { + "project_id": GCP_PROJECT, + "location": GCP_LOCATION, + "operation_id": QUERY_OPERATION_ID, + "config": CHECK_QUERY_CONFIG, + "gcp_conn_id": GCP_CONN_ID, + "impersonation_chain": IMPERSONATION_CHAIN, + "poll_interval": 1, + "timeout": 60, + }, + ) + + def test_serialize_with_pydantic_config(self): + pydantic_config = FakeModel(CHECK_QUERY_CONFIG) + trigger = AgentEngineQueryJobTrigger( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=pydantic_config, + gcp_conn_id=GCP_CONN_ID, + poll_interval=1, + ) + _, kwargs = trigger.serialize() + assert kwargs["config"] == CHECK_QUERY_CONFIG + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_success_event(self, mock_hook, query_job_trigger): + query_job = { + "operation_name": QUERY_OPERATION_NAME, + "output_gcs_uri": "gs://test-bucket/query-output/output.json", + "status": "SUCCESS", + "result": "done", + } + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel(query_job) + + event = await query_job_trigger.run().asend(None) + + mock_hook.return_value.check_query_agent_engine_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=QUERY_OPERATION_ID, + config=CHECK_QUERY_CONFIG, + ) + assert event == TriggerEvent( + { + "status": "success", + "message": "Agent Engine query job completed", + "query_job": query_job, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_polls_until_success(self, mock_hook, mock_sleep, query_job_trigger): + running_job = FakeModel({"operation_name": QUERY_OPERATION_NAME, "status": "RUNNING"}) + success_job = FakeModel({"operation_name": QUERY_OPERATION_NAME, "status": "SUCCESS"}) + mock_hook.return_value.check_query_agent_engine_job.side_effect = [ + running_job, + running_job, + success_job, + ] + + event = await query_job_trigger.run().asend(None) + + assert mock_hook.return_value.check_query_agent_engine_job.call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_awaited_with(1) + assert event == TriggerEvent( + { + "status": "success", + "message": "Agent Engine query job completed", + "query_job": {"operation_name": QUERY_OPERATION_NAME, "status": "SUCCESS"}, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_failed_event(self, mock_hook, query_job_trigger): + query_job = {"operation_name": QUERY_OPERATION_NAME, "status": "FAILED"} + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel(query_job) + + event = await query_job_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": f"Agent Engine query job {QUERY_OPERATION_ID} failed.", + "query_job": query_job, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, query_job_trigger): + query_job_trigger.timeout = -1 + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel({"status": "RUNNING"}) + + event = await query_job_trigger.run().asend(None) + + mock_sleep.assert_not_called() + assert event == TriggerEvent( + { + "status": "timeout", + "message": f"Timed out waiting for Agent Engine query job {QUERY_OPERATION_ID}", + "query_job": {"status": "RUNNING"}, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_error_event_for_unexpected_status(self, mock_hook, query_job_trigger): + query_job = {"operation_name": QUERY_OPERATION_NAME, "status": "CANCELLED"} + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel(query_job) + + event = await query_job_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": ( + f"Agent Engine query job {QUERY_OPERATION_ID} completed with unexpected status CANCELLED." + ), + "query_job": query_job, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_error_event(self, mock_hook, query_job_trigger): + mock_hook.return_value.check_query_agent_engine_job.side_effect = RuntimeError("boom") + + event = await query_job_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": "Failed while polling Agent Engine query job: boom", + "query_job": {"operation_id": QUERY_OPERATION_ID}, + } + ) diff --git a/uv.lock b/uv.lock index eac0d996a7baf..d411a08586152 100644 --- a/uv.lock +++ b/uv.lock @@ -5449,6 +5449,7 @@ dependencies = [ { name = "google-cloud-videointelligence" }, { name = "google-cloud-vision" }, { name = "google-cloud-workflows" }, + { name = "google-genai" }, { name = "grpcio-gcp" }, { name = "httpx" }, { name = "immutabledict" }, @@ -5642,6 +5643,7 @@ requires-dist = [ { name = "google-cloud-videointelligence", specifier = ">=2.11.0" }, { name = "google-cloud-vision", specifier = ">=3.4.0" }, { name = "google-cloud-workflows", specifier = ">=1.10.0" }, + { name = "google-genai", specifier = ">=2.8.0" }, { name = "grpcio-gcp", specifier = ">=0.2.2" }, { name = "httpx", specifier = ">=0.25.0" }, { name = "immutabledict", specifier = ">=4.2.0" },