From 1df7a6bb8191a4397fca68e394f5c73b72924c7c Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:14:58 -0300 Subject: [PATCH 01/32] Add Vertex AI Agent Engine operators (Create, Get, Query, Update, Delete) Adds hook, operators, deferrable trigger, unit tests, and a system test for the Vertex AI Agent Engine (Reasoning Engine) API. The Query operator uses the REST endpoint directly because the public SDK only exposes run_query_job for async/GCS-backed queries, while the private _query helper is affected by a response-parsing bug in google-genai==2.8.0. --- .../google/docs/operators/cloud/vertex_ai.rst | 50 +++ providers/google/provider.yaml | 2 + .../cloud/hooks/vertex_ai/agent_engine.py | 203 +++++++++ .../cloud/operators/vertex_ai/agent_engine.py | 392 ++++++++++++++++++ .../google/cloud/triggers/vertex_ai.py | 89 ++++ .../providers/google/get_provider_info.py | 2 + .../example_vertex_ai_agent_engine.py | 153 +++++++ .../hooks/vertex_ai/test_agent_engine.py | 233 +++++++++++ .../operators/vertex_ai/test_agent_engine.py | 330 +++++++++++++++ .../triggers/test_vertex_ai_agent_engine.py | 113 +++++ 10 files changed, 1567 insertions(+) create mode 100644 providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py create mode 100644 providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py create mode 100644 providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py create mode 100644 providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py create mode 100644 providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py create mode 100644 providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index 246788c9e9302..2a82edd9cd63b 
100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -26,6 +26,56 @@ With Vertex AI, both AutoML training and custom training are available options. Whichever option you choose for training, you can save models, deploy models, and request predictions with Vertex AI. +Managing Agent Engines +^^^^^^^^^^^^^^^^^^^^^^ + +To create a Vertex AI Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.CreateAgentEngineOperator`. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_create_agent_engine_operator] + +To get an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.GetAgentEngineOperator`. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_get_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_get_agent_engine_operator] + +To query an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.QueryAgentEngineOperator`. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_query_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_query_agent_engine_operator] + +To update an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.UpdateAgentEngineOperator`. + +.. 
exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_update_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_update_agent_engine_operator] + +To delete an Agent Engine you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.DeleteAgentEngineOperator`. +By default, the operator waits until the Agent Engine no longer exists. The same operation +can be performed in the deferrable mode. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_delete_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_delete_agent_engine_operator] + Creating Datasets ^^^^^^^^^^^^^^^^^ diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 44b6ca5efd6e9..b72ddd525382d 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -615,6 +615,7 @@ operators: - airflow.providers.google.cloud.operators.vertex_ai.generative_model - airflow.providers.google.cloud.operators.vertex_ai.feature_store - airflow.providers.google.cloud.operators.vertex_ai.ray + - airflow.providers.google.cloud.operators.vertex_ai.agent_engine - integration-name: Google Data Studio python-modules: - airflow.providers.google.cloud.operators.looker @@ -901,6 +902,7 @@ hooks: - airflow.providers.google.cloud.hooks.vertex_ai.generative_model - airflow.providers.google.cloud.hooks.vertex_ai.prediction_service - airflow.providers.google.cloud.hooks.vertex_ai.feature_store - airflow.providers.google.cloud.hooks.vertex_ai.ray + - airflow.providers.google.cloud.hooks.vertex_ai.agent_engine - integration-name: Google Data Studio python-modules: diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py 
b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py new file mode 100644 index 0000000000000..ab534ed285814 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -0,0 +1,203 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Cloud Vertex AI Agent Engine hook.""" + +from __future__ import annotations + +import time +from collections.abc import Sequence +from typing import Any + +import google.auth.transport.requests +from google.genai.errors import ClientError +from vertexai import Client + +from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook + + +class AgentEngineHook(GoogleBaseHook): + """Hook for Google Cloud Vertex AI Agent Engine APIs.""" + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + def get_agent_engine_client(self, project_id: str, location: str): + """Return the Vertex AI Agent Engine client.""" + return Client( + project=project_id, + location=location, + credentials=self.get_credentials(), + ).agent_engines + + @GoogleBaseHook.fallback_to_default_project_id + def create_agent_engine( + self, + location: str, + agent: Any | None = None, + agent_engine: Any | None = None, + config: Any | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> Any: + """ + Create an Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent: Optional. The agent object to deploy. + :param agent_engine: Optional. Deprecated alias for ``agent``. + :param config: Optional. Configuration for the Agent Engine. 
+ """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + return client.create(agent=agent, agent_engine=agent_engine, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def get_agent_engine( + self, + location: str, + name: str, + project_id: str = PROVIDE_PROJECT_ID, + ) -> Any: + """ + Get an Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + return client.get(name=name) + + @GoogleBaseHook.fallback_to_default_project_id + def query_agent_engine( + self, + location: str, + name: str, + config: Any | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> Any: + """ + Query an Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param config: Optional. Dict or pydantic model with the query request (``class_method``, ``input``). + """ + # Use the REST API directly rather than the SDK's run_query_job (requires GCS) or + # _query (private method; triggers a Pydantic parsing bug in google-genai 2.8.0 when + # the response output type is Any). 
+ cfg = config.model_dump(exclude_none=True) if hasattr(config, "model_dump") else dict(config or {}) + body: dict[str, Any] = {"classMethod": cfg.get("class_method", "query")} + if "input" in cfg: + body["input"] = cfg["input"] + + url = f"https://{location}-aiplatform.googleapis.com/v1beta1/{name}:query" + session = google.auth.transport.requests.AuthorizedSession(self.get_credentials()) + response = session.post(url, json=body) + response.raise_for_status() + data = response.json() + return data.get("output", data) + + @GoogleBaseHook.fallback_to_default_project_id + def update_agent_engine( + self, + location: str, + name: str, + config: Any, + agent: Any | None = None, + agent_engine: Any | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> Any: + """ + Update an Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param agent: Optional. The updated agent object to deploy. + :param agent_engine: Optional. Deprecated alias for ``agent``. + :param config: Required. Configuration for the Agent Engine update. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + return client.update(name=name, agent=agent, agent_engine=agent_engine, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_agent_engine( + self, + location: str, + name: str, + force: bool | None = None, + config: Any | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> Any: + """ + Delete an Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param force: Optional. Whether to delete child resources. 
+ :param config: Optional. 
Additional deletion configuration. + """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + return client.delete(name=name, force=force, config=config) + + def is_agent_engine_deleted(self, project_id: str, location: str, name: str) -> bool: + """Return whether an Agent Engine no longer exists.""" + try: + self.get_agent_engine(project_id=project_id, location=location, name=name) + except ClientError as err: + if getattr(err, "code", None) == 404: + return True + raise + return False + + def wait_for_agent_engine_deleted( + self, + project_id: str, + location: str, + name: str, + poll_interval: float, + timeout: float | None = None, + ) -> None: + """ + Wait until an Agent Engine no longer exists. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + """ + start_time = time.monotonic() + while True: + if self.is_agent_engine_deleted(project_id=project_id, location=location, name=name): + return + if timeout is not None and time.monotonic() - start_time > timeout: + raise TimeoutError(f"Timed out waiting for Agent Engine {name} to be deleted") + self.log.info("Waiting for Agent Engine %s to be deleted.", name) + time.sleep(poll_interval) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py new file mode 100644 index 0000000000000..295990a4dc93a --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -0,0 +1,392 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Vertex AI Agent Engine operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.compat.sdk import conf +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineDeleteTrigger + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +def validate_execute_complete_event(event: dict[str, Any] | None) -> dict[str, Any]: + if event is None: + raise RuntimeError("No event received in trigger callback") + return event + + +def _serialize_value(value: Any) -> Any: + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {key: _serialize_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_serialize_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_serialize_value(item) for item in value) + return value + + +def _serialize_agent_engine(agent_engine: Any) -> dict[str, Any]: + api_resource = 
getattr(agent_engine, "api_resource", None) + if api_resource is not None: + return _serialize_value(api_resource) + return _serialize_value(agent_engine) + + +class CreateAgentEngineOperator(GoogleCloudBaseOperator): + """ + Create a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent: Optional. The agent object to deploy. + :param agent_engine: Optional. Deprecated alias for ``agent``. + :param config: Optional. Configuration for the Agent Engine. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + """ + + template_fields = ( + "project_id", + "location", + "agent", + "agent_engine", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + agent: Any | None = None, + agent_engine: Any | None = None, + config: Any | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.agent = agent + self.agent_engine = agent_engine + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + agent_engine = self.hook.create_agent_engine( + project_id=self.project_id, + location=self.location, + agent=self.agent, + agent_engine=self.agent_engine, + config=self.config, + ) + result = _serialize_agent_engine(agent_engine) + context["ti"].xcom_push(key="agent_engine_name", 
value=result.get("name")) + return result + + +class GetAgentEngineOperator(GoogleCloudBaseOperator): + """ + Get a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + """ + + template_fields = ("project_id", "location", "name", "gcp_conn_id", "impersonation_chain") + + def __init__( + self, + *, + project_id: str, + location: str, + name: str, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.name = name + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + agent_engine = self.hook.get_agent_engine( + project_id=self.project_id, + location=self.location, + name=self.name, + ) + return _serialize_agent_engine(agent_engine) + + +class QueryAgentEngineOperator(GoogleCloudBaseOperator): + """ + Query a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param config: Optional. Configuration for the query request (``class_method``, ``input``). + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :param impersonation_chain: Optional service account to impersonate using short-term credentials. + """ + + template_fields = ("project_id", "location", "name", "config", "gcp_conn_id", "impersonation_chain") + + def __init__( + self, + *, + project_id: str, + location: str, + name: str, + config: Any | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.name = name + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> Any: + return self.hook.query_agent_engine( + project_id=self.project_id, + location=self.location, + name=self.name, + config=self.config, + ) + + +class UpdateAgentEngineOperator(GoogleCloudBaseOperator): + """ + Update a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param agent: Optional. The updated agent object to deploy. + :param agent_engine: Optional. Deprecated alias for ``agent``. + :param config: Required. Configuration for the Agent Engine update. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
+ """ + + template_fields = ( + "project_id", + "location", + "name", + "agent", + "agent_engine", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + name: str, + config: Any, + agent: Any | None = None, + agent_engine: Any | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.name = name + self.agent = agent + self.agent_engine = agent_engine + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + agent_engine = self.hook.update_agent_engine( + project_id=self.project_id, + location=self.location, + name=self.name, + agent=self.agent, + agent_engine=self.agent_engine, + config=self.config, + ) + return _serialize_agent_engine(agent_engine) + + +class DeleteAgentEngineOperator(GoogleCloudBaseOperator): + """ + Delete a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param name: Required. The Agent Engine resource name. + :param force: Optional. Whether to delete child resources. + :param config: Optional. Additional deletion configuration. + :param wait_for_completion: Whether to wait until the Agent Engine no longer exists. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :param impersonation_chain: Optional service account to impersonate using short-term credentials. + :param deferrable: Run operator in the deferrable mode. + """ + + template_fields = ( + "project_id", + "location", + "name", + "force", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + name: str, + force: bool | None = None, + config: Any | None = None, + wait_for_completion: bool = True, + poll_interval: float = 30, + timeout: float | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.name = name + self.force = force + self.config = config + self.wait_for_completion = wait_for_completion + self.poll_interval = poll_interval + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.deferrable = deferrable + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + operation = self.hook.delete_agent_engine( + project_id=self.project_id, + location=self.location, + name=self.name, + force=self.force, + config=self.config, + ) + result = _serialize_value(operation) + if not self.wait_for_completion: + return result + + if self.deferrable: + self.defer( + trigger=AgentEngineDeleteTrigger( + project_id=self.project_id, + location=self.location, + name=self.name, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + poll_interval=self.poll_interval, + timeout=self.timeout, + ), + method_name="execute_complete", + kwargs={"operation": result}, + ) + + 
self.hook.wait_for_agent_engine_deleted( + project_id=self.project_id, + location=self.location, + name=self.name, + poll_interval=self.poll_interval, + timeout=self.timeout, + ) + return result + + def execute_complete( + self, context: Context, event: dict[str, Any] | None = None, operation: dict[str, Any] | None = None + ) -> dict[str, Any]: + validated_event = validate_execute_complete_event(event) + if validated_event["status"] != "success": + raise RuntimeError(validated_event["message"]) + self.log.info("Agent Engine %s deleted.", validated_event["name"]) + return operation or {} diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index d39e96734fee5..1378da063e4a0 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +import asyncio +import time from collections.abc import AsyncIterator, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any @@ -29,6 +31,7 @@ ) from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( @@ -126,6 +129,92 @@ def _serialize_job(self, job: Any) -> Any: return self.job_serializer_class.to_dict(job) +class AgentEngineDeleteTrigger(BaseTrigger): + """Trigger that waits until a Vertex AI Agent Engine no longer exists.""" + + def __init__( + self, + project_id: str, + location: str, + name: str, + gcp_conn_id: str = "google_cloud_default", 
+ impersonation_chain: str | Sequence[str] | None = None, + poll_interval: float = 30, + timeout: float | None = None, + ): + super().__init__() + self.project_id = project_id + self.location = location + self.name = name + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.poll_interval = poll_interval + self.timeout = timeout + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger", + { + "project_id": self.project_id, + "location": self.location, + "name": self.name, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "poll_interval": self.poll_interval, + "timeout": self.timeout, + }, + ) + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + start_time = time.monotonic() + try: + while True: + deleted = await asyncio.to_thread( + self.hook.is_agent_engine_deleted, + project_id=self.project_id, + location=self.location, + name=self.name, + ) + if deleted: + yield TriggerEvent( + { + "status": "success", + "message": "Agent Engine deleted", + "name": self.name, + } + ) + return + + if self.timeout is not None and time.monotonic() - start_time > self.timeout: + yield TriggerEvent( + { + "status": "error", + "message": f"Timed out waiting for Agent Engine {self.name} to be deleted", + "name": self.name, + } + ) + return + + self.log.info("Waiting for Agent Engine %s to be deleted.", self.name) + await asyncio.sleep(self.poll_interval) + except Exception as err: + self.log.exception("Exception occurred while waiting for Agent Engine deletion.") + yield TriggerEvent( + { + "status": "error", + "message": str(err), + "name": self.name, + } + ) + + class CreateHyperparameterTuningJobTrigger(BaseVertexAIJobTrigger): 
"""CreateHyperparameterTuningJobTrigger run on the trigger worker to perform create operation.""" diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index 61571bc6831f3..e61b48a972277 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -672,6 +672,7 @@ def get_provider_info(): "airflow.providers.google.cloud.operators.vertex_ai.generative_model", "airflow.providers.google.cloud.operators.vertex_ai.feature_store", "airflow.providers.google.cloud.operators.vertex_ai.ray", + "airflow.providers.google.cloud.operators.vertex_ai.agent_engine", ], }, { @@ -1044,6 +1045,7 @@ def get_provider_info(): "airflow.providers.google.cloud.hooks.vertex_ai.prediction_service", "airflow.providers.google.cloud.hooks.vertex_ai.feature_store", "airflow.providers.google.cloud.hooks.vertex_ai.ray", + "airflow.providers.google.cloud.hooks.vertex_ai.agent_engine", ], }, { diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py new file mode 100644 index 0000000000000..c83728fbc4199 --- /dev/null +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Vertex AI Agent Engine operations. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( + CreateAgentEngineOperator, + DeleteAgentEngineOperator, + GetAgentEngineOperator, + QueryAgentEngineOperator, + UpdateAgentEngineOperator, +) + +try: + from airflow.sdk import DAG, TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.models.dag import DAG # type: ignore[attr-defined,no-redef,assignment] + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +DAG_ID = "vertex_ai_agent_engine_operations" +LOCATION = "us-central1" +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +CONTAINER_URI = os.environ.get( + "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_URI", + "us-central1-docker.pkg.dev/example-project/example-repository/example-agent:latest", +) + +AGENT_ENGINE_NAME = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'] }}" +DISPLAY_NAME = f"airflow-agent-engine-{ENV_ID}" + +QUERY_CONFIG = { + "class_method": "query", + "input": {"prompt": "Respond with a short acknowledgement."}, +} + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "vertex_ai", "agent_engine"], +) as dag: + # [START how_to_cloud_vertex_ai_create_agent_engine_operator] + create_agent_engine = CreateAgentEngineOperator( + 
task_id="create_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + config={ + "display_name": DISPLAY_NAME, + "description": "Airflow system test Agent Engine", + "agent_framework": "custom", + "min_instances": 0, + "max_instances": 1, + "resource_limits": {"cpu": "1", "memory": "1Gi"}, + "container_spec": {"image_uri": CONTAINER_URI}, + "class_methods": [ + { + "name": "query", + "api_mode": "", + }, + ], + }, + ) + # [END how_to_cloud_vertex_ai_create_agent_engine_operator] + + # [START how_to_cloud_vertex_ai_get_agent_engine_operator] + get_agent_engine = GetAgentEngineOperator( + task_id="get_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + name=AGENT_ENGINE_NAME, + ) + # [END how_to_cloud_vertex_ai_get_agent_engine_operator] + + # [START how_to_cloud_vertex_ai_query_agent_engine_operator] + query_agent_engine = QueryAgentEngineOperator( + task_id="query_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + name=AGENT_ENGINE_NAME, + config=QUERY_CONFIG, + ) + # [END how_to_cloud_vertex_ai_query_agent_engine_operator] + + # [START how_to_cloud_vertex_ai_update_agent_engine_operator] + update_agent_engine = UpdateAgentEngineOperator( + task_id="update_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + name=AGENT_ENGINE_NAME, + config={ + "display_name": f"{DISPLAY_NAME}-updated", + "description": "Updated Airflow system test Agent Engine", + }, + ) + # [END how_to_cloud_vertex_ai_update_agent_engine_operator] + + # [START how_to_cloud_vertex_ai_delete_agent_engine_operator] + delete_agent_engine = DeleteAgentEngineOperator( + task_id="delete_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + name=AGENT_ENGINE_NAME, + force=True, + deferrable=True, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END how_to_cloud_vertex_ai_delete_agent_engine_operator] + + ( + create_agent_engine + >> get_agent_engine + >> query_agent_engine + >> update_agent_engine + >> delete_agent_engine + ) + + # ### Everything below this 
line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst) +test_run = get_test_run(dag) diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py new file mode 100644 index 0000000000000..b22d7a82ce376 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -0,0 +1,233 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest +from google.genai.errors import ClientError + +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook + +from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +AGENT_ENGINE_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.agent_engine.{}" + +TEST_GCP_CONN_ID = "test-gcp-conn-id" +GCP_PROJECT = "test-project" +GCP_LOCATION = "us-central1" +AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" +CONFIG = {"display_name": "test-agent-engine"} +QUERY_CONFIG = {"class_method": "query", "input": {"prompt": "hello"}} + + +class TestAgentEngineHookWithDefaultProjectId: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = AgentEngineHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @mock.patch(AGENT_ENGINE_STRING.format("Client"), autospec=True) + def test_get_agent_engine_client(self, mock_client): + self.hook.get_credentials = mock.Mock(return_value=mock.sentinel.credentials, spec=()) + + result = self.hook.get_agent_engine_client(project_id=GCP_PROJECT, location=GCP_LOCATION) + + mock_client.assert_called_once_with( + project=GCP_PROJECT, + location=GCP_LOCATION, + credentials=mock.sentinel.credentials, + ) + assert result == mock_client.return_value.agent_engines + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_create_agent_engine(self, mock_get_client): + result = self.hook.create_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + config=CONFIG, + ) + + mock_get_client.assert_called_once_with(self.hook, project_id=GCP_PROJECT, location=GCP_LOCATION) + mock_get_client.return_value.create.assert_called_once_with( + agent=None, + 
agent_engine=None, + config=CONFIG, + ) + assert result == mock_get_client.return_value.create.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_get_agent_engine(self, mock_get_client): + result = self.hook.get_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME) + assert result == mock_get_client.return_value.get.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_credentials"), autospec=True) + def test_query_agent_engine(self, mock_get_credentials, mock_session_cls): + mock_response = mock.Mock() + mock_response.json.return_value = {"output": {"answer": "hello"}} + mock_session_cls.return_value.post.return_value = mock_response + + result = self.hook.query_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config=QUERY_CONFIG, + ) + + mock_session_cls.assert_called_once_with(mock_get_credentials.return_value) + mock_session_cls.return_value.post.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{AGENT_ENGINE_NAME}:query", + json={"classMethod": "query", "input": {"prompt": "hello"}}, + ) + mock_response.raise_for_status.assert_called_once_with() + assert result == {"answer": "hello"} + + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_credentials"), autospec=True) + def test_query_agent_engine_returns_full_response_when_output_missing( + self, mock_get_credentials, mock_session_cls + ): + full_response = {"someOtherField": "value"} + mock_response = mock.Mock() + mock_response.json.return_value = full_response + 
mock_session_cls.return_value.post.return_value = mock_response + + result = self.hook.query_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config=QUERY_CONFIG, + ) + + assert result == full_response + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_update_agent_engine(self, mock_get_client): + result = self.hook.update_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config=CONFIG, + ) + + mock_get_client.return_value.update.assert_called_once_with( + name=AGENT_ENGINE_NAME, + agent=None, + agent_engine=None, + config=CONFIG, + ) + assert result == mock_get_client.return_value.update.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_update_agent_engine_with_deprecated_agent_engine_alias(self, mock_get_client): + agent_engine = object() + + result = self.hook.update_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + agent_engine=agent_engine, + config=CONFIG, + ) + + mock_get_client.return_value.update.assert_called_once_with( + name=AGENT_ENGINE_NAME, + agent=None, + agent_engine=agent_engine, + config=CONFIG, + ) + assert result == mock_get_client.return_value.update.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_delete_agent_engine(self, mock_get_client): + result = self.hook.delete_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + force=True, + config=CONFIG, + ) + + mock_get_client.return_value.delete.assert_called_once_with( + name=AGENT_ENGINE_NAME, + force=True, + config=CONFIG, + ) + assert result == mock_get_client.return_value.delete.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) + def 
test_is_agent_engine_deleted_returns_false_when_resource_exists(self, mock_get_agent_engine): + assert not self.hook.is_agent_engine_deleted( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + mock_get_agent_engine.assert_called_once_with( + self.hook, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) + def test_is_agent_engine_deleted_returns_true_on_404(self, mock_get_agent_engine): + mock_get_agent_engine.side_effect = ClientError(code=404, response_json={"error": "not found"}) + + assert self.hook.is_agent_engine_deleted( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) + def test_is_agent_engine_deleted_reraises_non_404(self, mock_get_agent_engine): + mock_get_agent_engine.side_effect = ClientError(code=500, response_json={"error": "server error"}) + + with pytest.raises(ClientError) as err: + self.hook.is_agent_engine_deleted( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + assert err.value.code == 500 + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) + def test_is_agent_engine_deleted_reraises_non_404_with_404_in_message(self, mock_get_agent_engine): + mock_get_agent_engine.side_effect = ClientError( + code=500, + response_json={"error": "server error for resource 404"}, + ) + + with pytest.raises(ClientError) as err: + self.hook.is_agent_engine_deleted( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + assert err.value.code == 500 diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py new file mode 100644 index 0000000000000..8c88e1186a669 
--- /dev/null +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -0,0 +1,330 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.common.compat.sdk import TaskDeferred +from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( + CreateAgentEngineOperator, + DeleteAgentEngineOperator, + GetAgentEngineOperator, + QueryAgentEngineOperator, + UpdateAgentEngineOperator, +) + +AGENT_ENGINE_PATH = "airflow.providers.google.cloud.operators.vertex_ai.agent_engine.{}" + +TASK_ID = "test_task_id" +GCP_PROJECT = "test-project" +GCP_LOCATION = "us-central1" +GCP_CONN_ID = "test-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" +CONFIG = {"display_name": "test-agent-engine"} +QUERY_CONFIG = {"class_method": "query", "input": {"prompt": "hello"}} +OPERATION = {"name": "operations/delete-123", "done": False} + + +class FakeModel: + def __init__(self, payload): + self.payload = payload + + def model_dump(self, mode="json"): + return self.payload + + +class FakeAgentEngine: + def 
__init__(self, payload): + self.api_resource = FakeModel(payload) + + +@pytest.fixture +def context(): + return {"ti": mock.Mock(spec_set=["xcom_push"])} + + +def assert_hook_created(mock_hook): + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + +class TestCreateAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + mock_hook.return_value.create_agent_engine.return_value = FakeAgentEngine( + {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} + ) + op = CreateAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + assert_hook_created(mock_hook) + mock_hook.return_value.create_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent=None, + agent_engine=None, + config=CONFIG, + ) + context["ti"].xcom_push.assert_called_once_with( + key="agent_engine_name", + value=AGENT_ENGINE_NAME, + ) + assert result == {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} + + +class TestGetAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + mock_hook.return_value.get_agent_engine.return_value = FakeAgentEngine({"name": AGENT_ENGINE_NAME}) + op = GetAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.get_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + assert result == {"name": AGENT_ENGINE_NAME} + + +class TestQueryAgentEngineOperator: + 
@mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + mock_hook.return_value.query_agent_engine.return_value = {"output": "hello"} + op = QueryAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config=QUERY_CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.query_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config=QUERY_CONFIG, + ) + assert result == {"output": "hello"} + + +class TestUpdateAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + mock_hook.return_value.update_agent_engine.return_value = FakeAgentEngine( + {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + ) + op = UpdateAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.update_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + agent=None, + agent_engine=None, + config=CONFIG, + ) + assert result == {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_with_deprecated_agent_engine_alias(self, mock_hook, context): + agent_engine = object() + mock_hook.return_value.update_agent_engine.return_value = FakeAgentEngine( + {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + ) + op = UpdateAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + 
name=AGENT_ENGINE_NAME, + agent_engine=agent_engine, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.update_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + agent=None, + agent_engine=agent_engine, + config=CONFIG, + ) + assert result == {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + + +class TestDeleteAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_without_wait(self, mock_hook, context): + mock_hook.return_value.delete_agent_engine.return_value = FakeModel(OPERATION) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + force=True, + config=CONFIG, + wait_for_completion=False, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.delete_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + force=True, + config=CONFIG, + ) + mock_hook.return_value.wait_for_agent_engine_deleted.assert_not_called() + assert result == OPERATION + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_waits_until_deleted(self, mock_hook, context): + mock_hook.return_value.delete_agent_engine.return_value = FakeModel(OPERATION) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + wait_for_completion=True, + poll_interval=1, + timeout=60, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.wait_for_agent_engine_deleted.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + 
name=AGENT_ENGINE_NAME, + poll_interval=1, + timeout=60, + ) + assert result == OPERATION + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineDeleteTrigger"), autospec=True) + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_deferrable(self, mock_hook, mock_trigger, context): + mock_hook.return_value.delete_agent_engine.return_value = FakeModel(OPERATION) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + wait_for_completion=True, + poll_interval=1, + timeout=60, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(context=context) + + mock_trigger.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + poll_interval=1, + timeout=60, + ) + assert exc.value.kwargs == {"operation": OPERATION} + + def test_execute_complete_success(self, context): + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + result = op.execute_complete( + context=context, + event={"status": "success", "message": "Agent Engine deleted", "name": AGENT_ENGINE_NAME}, + operation=OPERATION, + ) + + assert result == OPERATION + + def test_execute_complete_error(self, context): + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + with pytest.raises(RuntimeError, match="boom"): + op.execute_complete(context=context, event={"status": "error", "message": "boom"}) + + def test_execute_complete_without_event(self, context): + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + with pytest.raises(RuntimeError, match="No 
event received in trigger callback"): + op.execute_complete(context=context) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py new file mode 100644 index 0000000000000..d0ab368904f65 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineDeleteTrigger +from airflow.triggers.base import TriggerEvent + +GCP_PROJECT = "test-project" +GCP_LOCATION = "us-central1" +GCP_CONN_ID = "test-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" + + +@pytest.fixture +def delete_trigger(): + return AgentEngineDeleteTrigger( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + poll_interval=1, + timeout=60, + ) + + +class TestAgentEngineDeleteTrigger: + def test_serialize(self, delete_trigger): + assert delete_trigger.serialize() == ( + "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger", + { + "project_id": GCP_PROJECT, + "location": GCP_LOCATION, + "name": AGENT_ENGINE_NAME, + "gcp_conn_id": GCP_CONN_ID, + "impersonation_chain": IMPERSONATION_CHAIN, + "poll_interval": 1, + "timeout": 60, + }, + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineHook", autospec=True) + async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): + mock_hook.return_value.is_agent_engine_deleted.return_value = True + + event = await delete_trigger.run().asend(None) + + mock_hook.return_value.is_agent_engine_deleted.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + assert event == TriggerEvent( + { + "status": "success", + "message": "Agent Engine deleted", + "name": AGENT_ENGINE_NAME, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineHook", autospec=True) + async def 
test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete_trigger): + delete_trigger.timeout = -1 + mock_hook.return_value.is_agent_engine_deleted.return_value = False + + event = await delete_trigger.run().asend(None) + + mock_sleep.assert_not_called() + assert event == TriggerEvent( + { + "status": "error", + "message": f"Timed out waiting for Agent Engine {AGENT_ENGINE_NAME} to be deleted", + "name": AGENT_ENGINE_NAME, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineHook", autospec=True) + async def test_run_loop_return_error_event(self, mock_hook, delete_trigger): + mock_hook.return_value.is_agent_engine_deleted.side_effect = RuntimeError("boom") + + event = await delete_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": "boom", + "name": AGENT_ENGINE_NAME, + } + ) From c5b29bed8147302d91f247b3c3c5bd449d652b87 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:42:44 -0300 Subject: [PATCH 02/32] Fix Vertex AI Agent Engine delete trigger --- .../cloud/hooks/vertex_ai/agent_engine.py | 34 ++++++++++++++++++- .../google/cloud/triggers/vertex_ai.py | 9 +++-- .../triggers/test_vertex_ai_agent_engine.py | 6 ++-- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index ab534ed285814..ccf8d72c6f1c9 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -24,10 +24,15 @@ from typing import Any import google.auth.transport.requests +from asgiref.sync import sync_to_async from google.genai.errors import ClientError from vertexai import Client -from 
airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import ( + PROVIDE_PROJECT_ID, + GoogleBaseAsyncHook, + GoogleBaseHook, +) class AgentEngineHook(GoogleBaseHook): @@ -201,3 +206,30 @@ def wait_for_agent_engine_deleted( raise TimeoutError(f"Timed out waiting for Agent Engine {name} to be deleted") self.log.info("Waiting for Agent Engine %s to be deleted.", name) time.sleep(poll_interval) + + +class AgentEngineAsyncHook(GoogleBaseAsyncHook): + """Async hook for Google Cloud Vertex AI Agent Engine APIs.""" + + sync_hook_class = AgentEngineHook + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ): + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + **kwargs, + ) + + async def is_agent_engine_deleted(self, project_id: str, location: str, name: str) -> bool: + """Return whether an Agent Engine no longer exists.""" + sync_hook = await self.get_sync_hook() + return await sync_to_async(sync_hook.is_agent_engine_deleted)( + project_id=project_id, + location=location, + name=name, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index 1378da063e4a0..ab98dbe4d20a2 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -31,7 +31,7 @@ ) from airflow.providers.common.compat.sdk import AirflowException -from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook from 
airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( @@ -166,8 +166,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: ) @cached_property - def hook(self) -> AgentEngineHook: - return AgentEngineHook( + def async_hook(self) -> AgentEngineAsyncHook: + return AgentEngineAsyncHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) @@ -176,8 +176,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: start_time = time.monotonic() try: while True: - deleted = await asyncio.to_thread( - self.hook.is_agent_engine_deleted, + deleted = await self.async_hook.is_agent_engine_deleted( project_id=self.project_id, location=self.location, name=self.name, diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index d0ab368904f65..93088799c4f72 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -60,7 +60,7 @@ def test_serialize(self, delete_trigger): ) @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineHook", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): mock_hook.return_value.is_agent_engine_deleted.return_value = True @@ -81,7 +81,7 @@ async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineHook", autospec=True) + 
@mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete_trigger): delete_trigger.timeout = -1 mock_hook.return_value.is_agent_engine_deleted.return_value = False @@ -98,7 +98,7 @@ async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete ) @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineHook", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_error_event(self, mock_hook, delete_trigger): mock_hook.return_value.is_agent_engine_deleted.side_effect = RuntimeError("boom") From 15e78692b702eba27772856dde94434901624870 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:34:31 -0300 Subject: [PATCH 03/32] Fix Google provider metadata ordering --- .../google/src/airflow/providers/google/get_provider_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index e61b48a972277..92966a4b3679b 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -1044,8 +1044,8 @@ def get_provider_info(): "airflow.providers.google.cloud.hooks.vertex_ai.generative_model", "airflow.providers.google.cloud.hooks.vertex_ai.prediction_service", "airflow.providers.google.cloud.hooks.vertex_ai.feature_store", - "airflow.providers.google.cloud.hooks.vertex_ai.ray", "airflow.providers.google.cloud.hooks.vertex_ai.agent_engine", + "airflow.providers.google.cloud.hooks.vertex_ai.ray", ], }, { From 06893eea87b9317820b24857b4d9b4826195e4be Mon Sep 17 00:00:00 2001 From: AlejandroMorgante 
<62363051+AlejandroMorgante@users.noreply.github.com> Date: Sat, 13 Jun 2026 13:03:41 -0300 Subject: [PATCH 04/32] Trigger CI From 1382d0a325620ff4a0e4b412bb864c9211d7c951 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Sat, 13 Jun 2026 14:43:02 -0300 Subject: [PATCH 05/32] Parse JSON string input in query_agent_engine to handle Jinja-templated XCom values --- .../cloud/hooks/vertex_ai/agent_engine.py | 11 ++++- .../hooks/vertex_ai/test_agent_engine.py | 42 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index ccf8d72c6f1c9..366cc5f1a2aca 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -19,6 +19,7 @@ from __future__ import annotations +import json import time from collections.abc import Sequence from typing import Any @@ -118,7 +119,15 @@ def query_agent_engine( cfg = config if isinstance(config, dict) else {} body: dict[str, Any] = {"classMethod": cfg.get("class_method", "query")} if "input" in cfg: - body["input"] = cfg["input"] + input_val = cfg["input"] + if isinstance(input_val, str): + try: + input_val = json.loads(input_val) + except json.JSONDecodeError as err: + raise ValueError("Agent Engine query input must be a JSON object.") from err + if not isinstance(input_val, dict): + raise ValueError("Agent Engine query input must be a JSON object.") + body["input"] = input_val url = f"https://{location}-aiplatform.googleapis.com/v1beta1/{name}:query" session = google.auth.transport.requests.AuthorizedSession(self.get_credentials()) diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py 
b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index b22d7a82ce376..c2c277c552a97 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -125,6 +125,48 @@ def test_query_agent_engine_returns_full_response_when_output_missing( assert result == full_response + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_credentials"), autospec=True) + def test_query_agent_engine_parses_json_string_input(self, mock_get_credentials, mock_session_cls): + mock_response = mock.Mock() + mock_response.json.return_value = {"output": {"answer": "hello"}} + mock_session_cls.return_value.post.return_value = mock_response + + result = self.hook.query_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config={"class_method": "query", "input": '{"prompt": "hello"}'}, + ) + + mock_session_cls.return_value.post.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{AGENT_ENGINE_NAME}:query", + json={"classMethod": "query", "input": {"prompt": "hello"}}, + ) + assert result == {"answer": "hello"} + + @pytest.mark.parametrize( + "input_value", + [ + "test string", + '"test string"', + '["prompt", "hello"]', + 1, + ["prompt", "hello"], + ], + ) + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + def test_query_agent_engine_rejects_non_object_input(self, mock_session_cls, input_value): + with pytest.raises(ValueError, match="Agent Engine query input must be a JSON object."): + self.hook.query_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config={"class_method": "query", "input": input_value}, + ) + + mock_session_cls.assert_not_called() + 
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) def test_update_agent_engine(self, mock_get_client): result = self.hook.update_agent_engine( From 1203c44b5786daafcb6385e9c16805a40ee49c76 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Mon, 15 Jun 2026 00:25:37 -0300 Subject: [PATCH 06/32] Address review comments on Vertex AI Agent Engine operators - Fix error message: JSONDecodeError now raises "must be valid JSON" - Replace hardcoded REST URL with sdk_client._api_client.request(), letting the SDK handle URL construction and auth (removes google.auth.transport dep) - Add request_timeout param to query_agent_engine and QueryAgentEngineOperator - Add poll_interval=30 default to wait_for_agent_engine_deleted - Fix docstrings: project_id marked Optional and moved to end of param list - Clarify force=None docstring: distinguishes "not specified" from False - Remove validate_execute_complete_event helper (inlined in execute_complete) - Remove redundant xcom_push("agent_engine_name") from CreateAgentEngineOperator - Trigger: use "timeout" status (not "error") and >= for boundary check - execute_complete: explicit branches for success / timeout / error --- .../cloud/hooks/vertex_ai/agent_engine.py | 52 ++++++++------ .../cloud/operators/vertex_ai/agent_engine.py | 27 ++++---- .../google/cloud/triggers/vertex_ai.py | 4 +- .../hooks/vertex_ai/test_agent_engine.py | 67 ++++++++++--------- .../operators/vertex_ai/test_agent_engine.py | 19 ++++-- 5 files changed, 95 insertions(+), 74 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 366cc5f1a2aca..8b5d95d41fcf4 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ 
b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -24,8 +24,8 @@ from collections.abc import Sequence from typing import Any -import google.auth.transport.requests from asgiref.sync import sync_to_async +from google.genai._api_client import HttpOptions from google.genai.errors import ClientError from vertexai import Client @@ -71,11 +71,12 @@ def create_agent_engine( """ Create an Agent Engine. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent: Optional. The agent object to deploy. :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Optional. Configuration for the Agent Engine. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) return client.create(agent=agent, agent_engine=agent_engine, config=config) @@ -90,9 +91,10 @@ def get_agent_engine( """ Get an Agent Engine. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param name: Required. The Agent Engine resource name. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) return client.get(name=name) @@ -103,19 +105,23 @@ def query_agent_engine( location: str, name: str, config: Any | None = None, + request_timeout: float | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> Any: """ Query an Agent Engine. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. 
The ID of the Google Cloud location that the service belongs to. :param name: Required. The Agent Engine resource name. :param config: Optional. Configuration for the query request (``class_method``, ``input``). + :param request_timeout: Optional. Timeout in seconds for the HTTP request. Defaults to no timeout. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. """ - # Use the REST API directly rather than the SDK's run_query_job (requires GCS) or - # _query (private method; triggers a Pydantic parsing bug in google-genai 2.8.0 when - # the response output type is Any). + # Use the SDK's _api_client.request() directly rather than the SDK's run_query_job + # (requires GCS) or _query (private method; triggers a Pydantic parsing bug in + # google-genai 2.8.0 when the response output type is Any). Calling request() bypasses + # Pydantic parsing while still letting the SDK handle URL construction and auth. cfg = config if isinstance(config, dict) else {} body: dict[str, Any] = {"classMethod": cfg.get("class_method", "query")} if "input" in cfg: @@ -124,16 +130,17 @@ def query_agent_engine( try: input_val = json.loads(input_val) except json.JSONDecodeError as err: - raise ValueError("Agent Engine query input must be a JSON object.") from err + raise ValueError("Agent Engine query input must be valid JSON.") from err if not isinstance(input_val, dict): raise ValueError("Agent Engine query input must be a JSON object.") body["input"] = input_val - url = f"https://{location}-aiplatform.googleapis.com/v1beta1/{name}:query" - session = google.auth.transport.requests.AuthorizedSession(self.get_credentials()) - response = session.post(url, json=body) - response.raise_for_status() - data = response.json() + sdk_client = self.get_agent_engine_client(project_id=project_id, location=location) + http_options = HttpOptions( + timeout=int(request_timeout * 1000) if request_timeout is not None else None + ) + 
response = sdk_client._api_client.request("post", f"{name}:query", body, http_options) + data = {} if not response.body else json.loads(response.body) return data.get("output", data) @GoogleBaseHook.fallback_to_default_project_id @@ -149,12 +156,13 @@ def update_agent_engine( """ Update an Agent Engine. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param name: Required. The Agent Engine resource name. + :param config: Required. Configuration for the Agent Engine update. :param agent: Optional. The updated agent object to deploy. :param agent_engine: Optional. Deprecated alias for ``agent``. - :param config: Required. Configuration for the Agent Engine update. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) return client.update(name=name, agent=agent, agent_engine=agent_engine, config=config) @@ -171,11 +179,13 @@ def delete_agent_engine( """ Delete an Agent Engine. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param name: Required. The Agent Engine resource name. - :param force: Optional. Whether to delete child resources. + :param force: Optional. Whether to forcefully delete child resources. Defaults to ``False`` + when not specified. :param config: Optional. Additional deletion configuration. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. 
""" client = self.get_agent_engine_client(project_id=project_id, location=location) return client.delete(name=name, force=force, config=config) @@ -195,15 +205,15 @@ def wait_for_agent_engine_deleted( project_id: str, location: str, name: str, - poll_interval: float, + poll_interval: float = 30, timeout: float | None = None, ) -> None: """ Wait until an Agent Engine no longer exists. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param project_id: The ID of the Google Cloud project that the service belongs to. + :param location: The ID of the Google Cloud location that the service belongs to. + :param name: The Agent Engine resource name. :param poll_interval: Time, in seconds, to wait between checks. :param timeout: Optional timeout, in seconds. """ diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 295990a4dc93a..c04f40afbdf79 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -32,12 +32,6 @@ from airflow.providers.common.compat.sdk import Context -def validate_execute_complete_event(event: dict[str, Any] | None) -> dict[str, Any]: - if event is None: - raise RuntimeError("No event received in trigger callback") - return event - - def _serialize_value(value: Any) -> Any: if hasattr(value, "model_dump"): return value.model_dump(mode="json") @@ -116,9 +110,7 @@ def execute(self, context: Context) -> dict[str, Any]: agent_engine=self.agent_engine, config=self.config, ) - result = _serialize_agent_engine(agent_engine) - context["ti"].xcom_push(key="agent_engine_name", 
value=result.get("name")) - return result + return _serialize_agent_engine(agent_engine) class GetAgentEngineOperator(GoogleCloudBaseOperator): @@ -175,6 +167,7 @@ class QueryAgentEngineOperator(GoogleCloudBaseOperator): :param location: Required. The ID of the Google Cloud location that the service belongs to. :param name: Required. The Agent Engine resource name. :param config: Optional. Configuration for the query request (``class_method``, ``input``). + :param request_timeout: Optional. Timeout in seconds for the HTTP request. Defaults to no timeout. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. """ @@ -188,6 +181,7 @@ def __init__( location: str, name: str, config: Any | None = None, + request_timeout: float | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -197,6 +191,7 @@ def __init__( self.location = location self.name = name self.config = config + self.request_timeout = request_timeout self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -213,6 +208,7 @@ def execute(self, context: Context) -> Any: location=self.location, name=self.name, config=self.config, + request_timeout=self.request_timeout, ) @@ -385,8 +381,11 @@ def execute(self, context: Context) -> dict[str, Any]: def execute_complete( self, context: Context, event: dict[str, Any] | None = None, operation: dict[str, Any] | None = None ) -> dict[str, Any]: - validated_event = validate_execute_complete_event(event) - if validated_event["status"] != "success": - raise RuntimeError(validated_event["message"]) - self.log.info("Agent Engine %s deleted.", validated_event["name"]) - return operation or {} + if event is None: + raise RuntimeError("No event received in trigger callback") + if event["status"] == "success": + self.log.info("Agent Engine %s deleted.", event["name"]) + 
return operation or {} + if event["status"] == "timeout": + raise TimeoutError(event["message"]) + raise RuntimeError(event["message"]) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index ab98dbe4d20a2..174da1d16adf1 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -191,10 +191,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: ) return - if self.timeout is not None and time.monotonic() - start_time > self.timeout: + if self.timeout is not None and time.monotonic() - start_time >= self.timeout: yield TriggerEvent( { - "status": "error", + "status": "timeout", "message": f"Timed out waiting for Agent Engine {self.name} to be deleted", "name": self.name, } diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index c2c277c552a97..da0c929254a86 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +import json from unittest import mock import pytest @@ -84,12 +85,11 @@ def test_get_agent_engine(self, mock_get_client): mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME) assert result == mock_get_client.return_value.get.return_value - @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_credentials"), autospec=True) - def test_query_agent_engine(self, mock_get_credentials, mock_session_cls): - mock_response = mock.Mock() - mock_response.json.return_value = {"output": {"answer": "hello"}} - mock_session_cls.return_value.post.return_value = mock_response + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_query_agent_engine(self, mock_get_client): + mock_get_client.return_value._api_client.request.return_value.body = json.dumps( + {"output": {"answer": "hello"}} + ) result = self.hook.query_agent_engine( project_id=GCP_PROJECT, @@ -98,23 +98,18 @@ def test_query_agent_engine(self, mock_get_credentials, mock_session_cls): config=QUERY_CONFIG, ) - mock_session_cls.assert_called_once_with(mock_get_credentials.return_value) - mock_session_cls.return_value.post.assert_called_once_with( - f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{AGENT_ENGINE_NAME}:query", - json={"classMethod": "query", "input": {"prompt": "hello"}}, + mock_get_client.return_value._api_client.request.assert_called_once_with( + "post", + f"{AGENT_ENGINE_NAME}:query", + {"classMethod": "query", "input": {"prompt": "hello"}}, + mock.ANY, ) - mock_response.raise_for_status.assert_called_once_with() assert result == {"answer": "hello"} - @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_credentials"), autospec=True) - def 
test_query_agent_engine_returns_full_response_when_output_missing( - self, mock_get_credentials, mock_session_cls - ): + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_query_agent_engine_returns_full_response_when_output_missing(self, mock_get_client): full_response = {"someOtherField": "value"} - mock_response = mock.Mock() - mock_response.json.return_value = full_response - mock_session_cls.return_value.post.return_value = mock_response + mock_get_client.return_value._api_client.request.return_value.body = json.dumps(full_response) result = self.hook.query_agent_engine( project_id=GCP_PROJECT, @@ -125,12 +120,11 @@ def test_query_agent_engine_returns_full_response_when_output_missing( assert result == full_response - @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_credentials"), autospec=True) - def test_query_agent_engine_parses_json_string_input(self, mock_get_credentials, mock_session_cls): - mock_response = mock.Mock() - mock_response.json.return_value = {"output": {"answer": "hello"}} - mock_session_cls.return_value.post.return_value = mock_response + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_query_agent_engine_parses_json_string_input(self, mock_get_client): + mock_get_client.return_value._api_client.request.return_value.body = json.dumps( + {"output": {"answer": "hello"}} + ) result = self.hook.query_agent_engine( project_id=GCP_PROJECT, @@ -139,24 +133,33 @@ def test_query_agent_engine_parses_json_string_input(self, mock_get_credentials, config={"class_method": "query", "input": '{"prompt": "hello"}'}, ) - mock_session_cls.return_value.post.assert_called_once_with( - f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{AGENT_ENGINE_NAME}:query", - json={"classMethod": "query", "input": {"prompt": 
"hello"}}, + mock_get_client.return_value._api_client.request.assert_called_once_with( + "post", + f"{AGENT_ENGINE_NAME}:query", + {"classMethod": "query", "input": {"prompt": "hello"}}, + mock.ANY, ) assert result == {"answer": "hello"} + def test_query_agent_engine_rejects_invalid_json_string_input(self): + with pytest.raises(ValueError, match="Agent Engine query input must be valid JSON."): + self.hook.query_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + config={"class_method": "query", "input": "not valid json"}, + ) + @pytest.mark.parametrize( "input_value", [ - "test string", '"test string"', '["prompt", "hello"]', 1, ["prompt", "hello"], ], ) - @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) - def test_query_agent_engine_rejects_non_object_input(self, mock_session_cls, input_value): + def test_query_agent_engine_rejects_non_object_input(self, input_value): with pytest.raises(ValueError, match="Agent Engine query input must be a JSON object."): self.hook.query_agent_engine( project_id=GCP_PROJECT, @@ -165,8 +168,6 @@ def test_query_agent_engine_rejects_non_object_input(self, mock_session_cls, inp config={"class_method": "query", "input": input_value}, ) - mock_session_cls.assert_not_called() - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) def test_update_agent_engine(self, mock_get_client): result = self.hook.update_agent_engine( diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index 8c88e1186a669..6c40d0a166bbe 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -93,10 +93,6 @@ def test_execute(self, mock_hook, context): agent_engine=None, 
config=CONFIG, ) - context["ti"].xcom_push.assert_called_once_with( - key="agent_engine_name", - value=AGENT_ENGINE_NAME, - ) assert result == {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} @@ -144,6 +140,7 @@ def test_execute(self, mock_hook, context): location=GCP_LOCATION, name=AGENT_ENGINE_NAME, config=QUERY_CONFIG, + request_timeout=None, ) assert result == {"output": "hello"} @@ -318,6 +315,20 @@ def test_execute_complete_error(self, context): with pytest.raises(RuntimeError, match="boom"): op.execute_complete(context=context, event={"status": "error", "message": "boom"}) + def test_execute_complete_timeout(self, context): + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + name=AGENT_ENGINE_NAME, + ) + + with pytest.raises(TimeoutError, match="timed out"): + op.execute_complete( + context=context, + event={"status": "timeout", "message": "timed out", "name": AGENT_ENGINE_NAME}, + ) + def test_execute_complete_without_event(self, context): op = DeleteAgentEngineOperator( task_id=TASK_ID, From eaba45096649b67f4b0212e841bf540572fdb84f Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Mon, 15 Jun 2026 10:41:47 -0300 Subject: [PATCH 07/32] Fix Agent Engine delete trigger timeout test --- .../unit/google/cloud/triggers/test_vertex_ai_agent_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index 93088799c4f72..82f18acf58e7d 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -91,7 +91,7 @@ async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete mock_sleep.assert_not_called() assert event == 
TriggerEvent( { - "status": "error", + "status": "timeout", "message": f"Timed out waiting for Agent Engine {AGENT_ENGINE_NAME} to be deleted", "name": AGENT_ENGINE_NAME, } From 634f99dffa42af6a06e53ae1aa495a7a481f8765 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Mon, 15 Jun 2026 23:15:08 -0300 Subject: [PATCH 08/32] Use Agent Engine ID in Vertex AI operators --- .../cloud/hooks/vertex_ai/agent_engine.py | 51 ++++++++++++------ .../cloud/operators/vertex_ai/agent_engine.py | 53 +++++++++++-------- .../google/cloud/triggers/vertex_ai.py | 20 +++---- .../example_vertex_ai_agent_engine.py | 10 ++-- .../hooks/vertex_ai/test_agent_engine.py | 29 +++++----- .../operators/vertex_ai/test_agent_engine.py | 45 +++++++++------- .../triggers/test_vertex_ai_agent_engine.py | 16 +++--- 7 files changed, 128 insertions(+), 96 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 8b5d95d41fcf4..5215e6e284041 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -59,6 +59,11 @@ def get_agent_engine_client(self, project_id: str, location: str): credentials=self.get_credentials(), ).agent_engines + @staticmethod + def build_agent_engine_name(project_id: str, location: str, agent_engine_id: str) -> str: + """Build a fully qualified Agent Engine resource name.""" + return f"projects/{project_id}/locations/{location}/reasoningEngines/{agent_engine_id}" + @GoogleBaseHook.fallback_to_default_project_id def create_agent_engine( self, @@ -85,25 +90,26 @@ def create_agent_engine( def get_agent_engine( self, location: str, - name: str, + agent_engine_id: str, project_id: str = PROVIDE_PROJECT_ID, ) -> Any: """ Get an Agent Engine. 
:param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) return client.get(name=name) @GoogleBaseHook.fallback_to_default_project_id def query_agent_engine( self, location: str, - name: str, + agent_engine_id: str, config: Any | None = None, request_timeout: float | None = None, project_id: str = PROVIDE_PROJECT_ID, @@ -112,7 +118,7 @@ def query_agent_engine( Query an Agent Engine. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. :param config: Optional. Configuration for the query request (``class_method``, ``input``). :param request_timeout: Optional. Timeout in seconds for the HTTP request. Defaults to no timeout. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project @@ -139,6 +145,7 @@ def query_agent_engine( http_options = HttpOptions( timeout=int(request_timeout * 1000) if request_timeout is not None else None ) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) response = sdk_client._api_client.request("post", f"{name}:query", body, http_options) data = {} if not response.body else json.loads(response.body) return data.get("output", data) @@ -147,7 +154,7 @@ def query_agent_engine( def update_agent_engine( self, location: str, - name: str, + agent_engine_id: str, config: Any, agent: Any | None = None, agent_engine: Any | None = None, @@ -157,7 +164,7 @@ def update_agent_engine( Update an Agent Engine. 
:param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. :param config: Required. Configuration for the Agent Engine update. :param agent: Optional. The updated agent object to deploy. :param agent_engine: Optional. Deprecated alias for ``agent``. @@ -165,13 +172,14 @@ def update_agent_engine( configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) return client.update(name=name, agent=agent, agent_engine=agent_engine, config=config) @GoogleBaseHook.fallback_to_default_project_id def delete_agent_engine( self, location: str, - name: str, + agent_engine_id: str, force: bool | None = None, config: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, @@ -180,7 +188,7 @@ def delete_agent_engine( Delete an Agent Engine. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. :param force: Optional. Whether to forcefully delete child resources. Defaults to ``False`` when not specified. :param config: Optional. Additional deletion configuration. @@ -188,12 +196,17 @@ def delete_agent_engine( configured in the connection. 
""" client = self.get_agent_engine_client(project_id=project_id, location=location) + name = self.build_agent_engine_name(project_id, location, agent_engine_id) return client.delete(name=name, force=force, config=config) - def is_agent_engine_deleted(self, project_id: str, location: str, name: str) -> bool: + def is_agent_engine_deleted(self, project_id: str, location: str, agent_engine_id: str) -> bool: """Return whether an Agent Engine no longer exists.""" try: - self.get_agent_engine(project_id=project_id, location=location, name=name) + self.get_agent_engine( + project_id=project_id, + location=location, + agent_engine_id=agent_engine_id, + ) except ClientError as err: if getattr(err, "code", None) == 404: return True @@ -204,7 +217,7 @@ def wait_for_agent_engine_deleted( self, project_id: str, location: str, - name: str, + agent_engine_id: str, poll_interval: float = 30, timeout: float | None = None, ) -> None: @@ -213,17 +226,21 @@ def wait_for_agent_engine_deleted( :param project_id: The ID of the Google Cloud project that the service belongs to. :param location: The ID of the Google Cloud location that the service belongs to. - :param name: The Agent Engine resource name. + :param agent_engine_id: The Agent Engine ID. :param poll_interval: Time, in seconds, to wait between checks. :param timeout: Optional timeout, in seconds. 
""" start_time = time.monotonic() while True: - if self.is_agent_engine_deleted(project_id=project_id, location=location, name=name): + if self.is_agent_engine_deleted( + project_id=project_id, + location=location, + agent_engine_id=agent_engine_id, + ): return if timeout is not None and time.monotonic() - start_time > timeout: - raise TimeoutError(f"Timed out waiting for Agent Engine {name} to be deleted") - self.log.info("Waiting for Agent Engine %s to be deleted.", name) + raise TimeoutError(f"Timed out waiting for Agent Engine {agent_engine_id} to be deleted") + self.log.info("Waiting for Agent Engine %s to be deleted.", agent_engine_id) time.sleep(poll_interval) @@ -244,11 +261,11 @@ def __init__( **kwargs, ) - async def is_agent_engine_deleted(self, project_id: str, location: str, name: str) -> bool: + async def is_agent_engine_deleted(self, project_id: str, location: str, agent_engine_id: str) -> bool: """Return whether an Agent Engine no longer exists.""" sync_hook = await self.get_sync_hook() return await sync_to_async(sync_hook.is_agent_engine_deleted)( project_id=project_id, location=location, - name=name, + agent_engine_id=agent_engine_id, ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index c04f40afbdf79..5c517a1dadf3d 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -119,19 +119,19 @@ class GetAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. 
:param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. """ - template_fields = ("project_id", "location", "name", "gcp_conn_id", "impersonation_chain") + template_fields = ("project_id", "location", "agent_engine_id", "gcp_conn_id", "impersonation_chain") def __init__( self, *, project_id: str, location: str, - name: str, + agent_engine_id: str, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -139,7 +139,7 @@ def __init__( super().__init__(**kwargs) self.project_id = project_id self.location = location - self.name = name + self.agent_engine_id = agent_engine_id self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -154,7 +154,7 @@ def execute(self, context: Context) -> dict[str, Any]: agent_engine = self.hook.get_agent_engine( project_id=self.project_id, location=self.location, - name=self.name, + agent_engine_id=self.agent_engine_id, ) return _serialize_agent_engine(agent_engine) @@ -165,21 +165,28 @@ class QueryAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. :param config: Optional. Configuration for the query request (``class_method``, ``input``). :param request_timeout: Optional. Timeout in seconds for the HTTP request. Defaults to no timeout. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
""" - template_fields = ("project_id", "location", "name", "config", "gcp_conn_id", "impersonation_chain") + template_fields = ( + "project_id", + "location", + "agent_engine_id", + "config", + "gcp_conn_id", + "impersonation_chain", + ) def __init__( self, *, project_id: str, location: str, - name: str, + agent_engine_id: str, config: Any | None = None, request_timeout: float | None = None, gcp_conn_id: str = "google_cloud_default", @@ -189,7 +196,7 @@ def __init__( super().__init__(**kwargs) self.project_id = project_id self.location = location - self.name = name + self.agent_engine_id = agent_engine_id self.config = config self.request_timeout = request_timeout self.gcp_conn_id = gcp_conn_id @@ -206,7 +213,7 @@ def execute(self, context: Context) -> Any: return self.hook.query_agent_engine( project_id=self.project_id, location=self.location, - name=self.name, + agent_engine_id=self.agent_engine_id, config=self.config, request_timeout=self.request_timeout, ) @@ -218,7 +225,7 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. :param agent: Optional. The updated agent object to deploy. :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Required. Configuration for the Agent Engine update. 
@@ -229,7 +236,7 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): template_fields = ( "project_id", "location", - "name", + "agent_engine_id", "agent", "agent_engine", "config", @@ -242,7 +249,7 @@ def __init__( *, project_id: str, location: str, - name: str, + agent_engine_id: str, config: Any, agent: Any | None = None, agent_engine: Any | None = None, @@ -253,7 +260,7 @@ def __init__( super().__init__(**kwargs) self.project_id = project_id self.location = location - self.name = name + self.agent_engine_id = agent_engine_id self.agent = agent self.agent_engine = agent_engine self.config = config @@ -271,7 +278,7 @@ def execute(self, context: Context) -> dict[str, Any]: agent_engine = self.hook.update_agent_engine( project_id=self.project_id, location=self.location, - name=self.name, + agent_engine_id=self.agent_engine_id, agent=self.agent, agent_engine=self.agent_engine, config=self.config, @@ -285,7 +292,7 @@ class DeleteAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param name: Required. The Agent Engine resource name. + :param agent_engine_id: Required. The Agent Engine ID. :param force: Optional. Whether to delete child resources. :param config: Optional. Additional deletion configuration. :param wait_for_completion: Whether to wait until the Agent Engine no longer exists. 
@@ -299,7 +306,7 @@ class DeleteAgentEngineOperator(GoogleCloudBaseOperator): template_fields = ( "project_id", "location", - "name", + "agent_engine_id", "force", "config", "gcp_conn_id", @@ -311,7 +318,7 @@ def __init__( *, project_id: str, location: str, - name: str, + agent_engine_id: str, force: bool | None = None, config: Any | None = None, wait_for_completion: bool = True, @@ -325,7 +332,7 @@ def __init__( super().__init__(**kwargs) self.project_id = project_id self.location = location - self.name = name + self.agent_engine_id = agent_engine_id self.force = force self.config = config self.wait_for_completion = wait_for_completion @@ -346,7 +353,7 @@ def execute(self, context: Context) -> dict[str, Any]: operation = self.hook.delete_agent_engine( project_id=self.project_id, location=self.location, - name=self.name, + agent_engine_id=self.agent_engine_id, force=self.force, config=self.config, ) @@ -359,7 +366,7 @@ def execute(self, context: Context) -> dict[str, Any]: trigger=AgentEngineDeleteTrigger( project_id=self.project_id, location=self.location, - name=self.name, + agent_engine_id=self.agent_engine_id, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, poll_interval=self.poll_interval, @@ -372,7 +379,7 @@ def execute(self, context: Context) -> dict[str, Any]: self.hook.wait_for_agent_engine_deleted( project_id=self.project_id, location=self.location, - name=self.name, + agent_engine_id=self.agent_engine_id, poll_interval=self.poll_interval, timeout=self.timeout, ) @@ -384,7 +391,7 @@ def execute_complete( if event is None: raise RuntimeError("No event received in trigger callback") if event["status"] == "success": - self.log.info("Agent Engine %s deleted.", event["name"]) + self.log.info("Agent Engine %s deleted.", event["agent_engine_id"]) return operation or {} if event["status"] == "timeout": raise TimeoutError(event["message"]) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py 
b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index 174da1d16adf1..718901f1506d7 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -136,7 +136,7 @@ def __init__( self, project_id: str, location: str, - name: str, + agent_engine_id: str, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, poll_interval: float = 30, @@ -145,7 +145,7 @@ def __init__( super().__init__() self.project_id = project_id self.location = location - self.name = name + self.agent_engine_id = agent_engine_id self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval @@ -157,7 +157,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: { "project_id": self.project_id, "location": self.location, - "name": self.name, + "agent_engine_id": self.agent_engine_id, "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, @@ -179,14 +179,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: deleted = await self.async_hook.is_agent_engine_deleted( project_id=self.project_id, location=self.location, - name=self.name, + agent_engine_id=self.agent_engine_id, ) if deleted: yield TriggerEvent( { "status": "success", "message": "Agent Engine deleted", - "name": self.name, + "agent_engine_id": self.agent_engine_id, } ) return @@ -195,13 +195,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent( { "status": "timeout", - "message": f"Timed out waiting for Agent Engine {self.name} to be deleted", - "name": self.name, + "message": ( + f"Timed out waiting for Agent Engine {self.agent_engine_id} to be deleted" + ), + "agent_engine_id": self.agent_engine_id, } ) return - self.log.info("Waiting for Agent Engine %s to be deleted.", self.name) + self.log.info("Waiting for Agent Engine %s 
to be deleted.", self.agent_engine_id) await asyncio.sleep(self.poll_interval) except Exception as err: self.log.exception("Exception occurred while waiting for Agent Engine deletion.") @@ -209,7 +211,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: { "status": "error", "message": str(err), - "name": self.name, + "agent_engine_id": self.agent_engine_id, } ) diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index c83728fbc4199..f2d33d27973e5 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -49,7 +49,7 @@ "us-central1-docker.pkg.dev/example-project/example-repository/example-agent:latest", ) -AGENT_ENGINE_NAME = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'] }}" +AGENT_ENGINE_ID = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'].split('/')[-1] }}" DISPLAY_NAME = f"airflow-agent-engine-{ENV_ID}" QUERY_CONFIG = { @@ -92,7 +92,7 @@ task_id="get_agent_engine", project_id=PROJECT_ID, location=LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) # [END how_to_cloud_vertex_ai_get_agent_engine_operator] @@ -101,7 +101,7 @@ task_id="query_agent_engine", project_id=PROJECT_ID, location=LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, ) # [END how_to_cloud_vertex_ai_query_agent_engine_operator] @@ -111,7 +111,7 @@ task_id="update_agent_engine", project_id=PROJECT_ID, location=LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config={ "display_name": f"{DISPLAY_NAME}-updated", "description": "Updated Airflow system test Agent Engine", @@ -124,7 +124,7 @@ task_id="delete_agent_engine", project_id=PROJECT_ID, location=LOCATION, - name=AGENT_ENGINE_NAME, + 
agent_engine_id=AGENT_ENGINE_ID, force=True, deferrable=True, trigger_rule=TriggerRule.ALL_DONE, diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index da0c929254a86..5598e731dca35 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -33,6 +33,7 @@ TEST_GCP_CONN_ID = "test-gcp-conn-id" GCP_PROJECT = "test-project" GCP_LOCATION = "us-central1" +AGENT_ENGINE_ID = "123" AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" CONFIG = {"display_name": "test-agent-engine"} QUERY_CONFIG = {"class_method": "query", "input": {"prompt": "hello"}} @@ -79,7 +80,7 @@ def test_get_agent_engine(self, mock_get_client): result = self.hook.get_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME) @@ -94,7 +95,7 @@ def test_query_agent_engine(self, mock_get_client): result = self.hook.query_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, ) @@ -114,7 +115,7 @@ def test_query_agent_engine_returns_full_response_when_output_missing(self, mock result = self.hook.query_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, ) @@ -129,7 +130,7 @@ def test_query_agent_engine_parses_json_string_input(self, mock_get_client): result = self.hook.query_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config={"class_method": "query", "input": '{"prompt": "hello"}'}, ) @@ -146,7 +147,7 @@ def 
test_query_agent_engine_rejects_invalid_json_string_input(self): self.hook.query_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config={"class_method": "query", "input": "not valid json"}, ) @@ -164,7 +165,7 @@ def test_query_agent_engine_rejects_non_object_input(self, input_value): self.hook.query_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config={"class_method": "query", "input": input_value}, ) @@ -173,7 +174,7 @@ def test_update_agent_engine(self, mock_get_client): result = self.hook.update_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=CONFIG, ) @@ -192,7 +193,7 @@ def test_update_agent_engine_with_deprecated_agent_engine_alias(self, mock_get_c result = self.hook.update_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, agent_engine=agent_engine, config=CONFIG, ) @@ -210,7 +211,7 @@ def test_delete_agent_engine(self, mock_get_client): result = self.hook.delete_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, force=True, config=CONFIG, ) @@ -227,13 +228,13 @@ def test_is_agent_engine_deleted_returns_false_when_resource_exists(self, mock_g assert not self.hook.is_agent_engine_deleted( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) mock_get_agent_engine.assert_called_once_with( self.hook, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) @@ -243,7 +244,7 @@ def test_is_agent_engine_deleted_returns_true_on_404(self, mock_get_agent_engine assert self.hook.is_agent_engine_deleted( 
project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) @@ -254,7 +255,7 @@ def test_is_agent_engine_deleted_reraises_non_404(self, mock_get_agent_engine): self.hook.is_agent_engine_deleted( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) assert err.value.code == 500 @@ -270,7 +271,7 @@ def test_is_agent_engine_deleted_reraises_non_404_with_404_in_message(self, mock self.hook.is_agent_engine_deleted( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) assert err.value.code == 500 diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index 6c40d0a166bbe..c636d7d878661 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -37,6 +37,7 @@ GCP_LOCATION = "us-central1" GCP_CONN_ID = "test-conn" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +AGENT_ENGINE_ID = "123" AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" CONFIG = {"display_name": "test-agent-engine"} QUERY_CONFIG = {"class_method": "query", "input": {"prompt": "hello"}} @@ -104,7 +105,7 @@ def test_execute(self, mock_hook, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) @@ -114,7 +115,7 @@ def test_execute(self, mock_hook, context): mock_hook.return_value.get_agent_engine.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + 
agent_engine_id=AGENT_ENGINE_ID, ) assert result == {"name": AGENT_ENGINE_NAME} @@ -127,7 +128,7 @@ def test_execute(self, mock_hook, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -138,7 +139,7 @@ def test_execute(self, mock_hook, context): mock_hook.return_value.query_agent_engine.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, request_timeout=None, ) @@ -155,7 +156,7 @@ def test_execute(self, mock_hook, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=CONFIG, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -166,7 +167,7 @@ def test_execute(self, mock_hook, context): mock_hook.return_value.update_agent_engine.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, agent=None, agent_engine=None, config=CONFIG, @@ -183,7 +184,7 @@ def test_execute_with_deprecated_agent_engine_alias(self, mock_hook, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, agent_engine=agent_engine, config=CONFIG, gcp_conn_id=GCP_CONN_ID, @@ -195,7 +196,7 @@ def test_execute_with_deprecated_agent_engine_alias(self, mock_hook, context): mock_hook.return_value.update_agent_engine.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, agent=None, agent_engine=agent_engine, config=CONFIG, @@ -211,7 +212,7 @@ def test_execute_without_wait(self, mock_hook, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + 
agent_engine_id=AGENT_ENGINE_ID, force=True, config=CONFIG, wait_for_completion=False, @@ -224,7 +225,7 @@ def test_execute_without_wait(self, mock_hook, context): mock_hook.return_value.delete_agent_engine.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, force=True, config=CONFIG, ) @@ -238,7 +239,7 @@ def test_execute_waits_until_deleted(self, mock_hook, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, wait_for_completion=True, poll_interval=1, timeout=60, @@ -251,7 +252,7 @@ def test_execute_waits_until_deleted(self, mock_hook, context): mock_hook.return_value.wait_for_agent_engine_deleted.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, poll_interval=1, timeout=60, ) @@ -265,7 +266,7 @@ def test_execute_deferrable(self, mock_hook, mock_trigger, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, wait_for_completion=True, poll_interval=1, timeout=60, @@ -280,7 +281,7 @@ def test_execute_deferrable(self, mock_hook, mock_trigger, context): mock_trigger.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, poll_interval=1, @@ -293,12 +294,16 @@ def test_execute_complete_success(self, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) result = op.execute_complete( context=context, - event={"status": "success", "message": "Agent Engine deleted", "name": AGENT_ENGINE_NAME}, + event={ + "status": "success", + "message": "Agent Engine deleted", + "agent_engine_id": AGENT_ENGINE_ID, + }, operation=OPERATION, ) @@ 
-309,7 +314,7 @@ def test_execute_complete_error(self, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) with pytest.raises(RuntimeError, match="boom"): @@ -320,13 +325,13 @@ def test_execute_complete_timeout(self, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) with pytest.raises(TimeoutError, match="timed out"): op.execute_complete( context=context, - event={"status": "timeout", "message": "timed out", "name": AGENT_ENGINE_NAME}, + event={"status": "timeout", "message": "timed out", "agent_engine_id": AGENT_ENGINE_ID}, ) def test_execute_complete_without_event(self, context): @@ -334,7 +339,7 @@ def test_execute_complete_without_event(self, context): task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) with pytest.raises(RuntimeError, match="No event received in trigger callback"): diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index 82f18acf58e7d..c49fd34391835 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -28,7 +28,7 @@ GCP_LOCATION = "us-central1" GCP_CONN_ID = "test-conn" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] -AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" +AGENT_ENGINE_ID = "123" @pytest.fixture @@ -36,7 +36,7 @@ def delete_trigger(): return AgentEngineDeleteTrigger( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, poll_interval=1, @@ -51,7 +51,7 @@ def 
test_serialize(self, delete_trigger): { "project_id": GCP_PROJECT, "location": GCP_LOCATION, - "name": AGENT_ENGINE_NAME, + "agent_engine_id": AGENT_ENGINE_ID, "gcp_conn_id": GCP_CONN_ID, "impersonation_chain": IMPERSONATION_CHAIN, "poll_interval": 1, @@ -69,13 +69,13 @@ async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): mock_hook.return_value.is_agent_engine_deleted.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - name=AGENT_ENGINE_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) assert event == TriggerEvent( { "status": "success", "message": "Agent Engine deleted", - "name": AGENT_ENGINE_NAME, + "agent_engine_id": AGENT_ENGINE_ID, } ) @@ -92,8 +92,8 @@ async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete assert event == TriggerEvent( { "status": "timeout", - "message": f"Timed out waiting for Agent Engine {AGENT_ENGINE_NAME} to be deleted", - "name": AGENT_ENGINE_NAME, + "message": f"Timed out waiting for Agent Engine {AGENT_ENGINE_ID} to be deleted", + "agent_engine_id": AGENT_ENGINE_ID, } ) @@ -108,6 +108,6 @@ async def test_run_loop_return_error_event(self, mock_hook, delete_trigger): { "status": "error", "message": "boom", - "name": AGENT_ENGINE_NAME, + "agent_engine_id": AGENT_ENGINE_ID, } ) From caa83ce49ac50b14aa1722a71cb0166fa225d29e Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Mon, 15 Jun 2026 23:25:43 -0300 Subject: [PATCH 09/32] Log Vertex AI Agent Engine operator actions --- .../cloud/operators/vertex_ai/agent_engine.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 5c517a1dadf3d..436512f352704 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py 
+++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -103,6 +103,7 @@ def hook(self) -> AgentEngineHook: ) def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Creating Agent Engine.") agent_engine = self.hook.create_agent_engine( project_id=self.project_id, location=self.location, @@ -110,7 +111,9 @@ def execute(self, context: Context) -> dict[str, Any]: agent_engine=self.agent_engine, config=self.config, ) - return _serialize_agent_engine(agent_engine) + result = _serialize_agent_engine(agent_engine) + self.log.info("Agent Engine was created.") + return result class GetAgentEngineOperator(GoogleCloudBaseOperator): @@ -151,12 +154,15 @@ def hook(self) -> AgentEngineHook: ) def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Getting Agent Engine %s.", self.agent_engine_id) agent_engine = self.hook.get_agent_engine( project_id=self.project_id, location=self.location, agent_engine_id=self.agent_engine_id, ) - return _serialize_agent_engine(agent_engine) + result = _serialize_agent_engine(agent_engine) + self.log.info("Agent Engine %s was retrieved.", self.agent_engine_id) + return result class QueryAgentEngineOperator(GoogleCloudBaseOperator): @@ -210,13 +216,16 @@ def hook(self) -> AgentEngineHook: ) def execute(self, context: Context) -> Any: - return self.hook.query_agent_engine( + self.log.info("Querying Agent Engine %s.", self.agent_engine_id) + result = self.hook.query_agent_engine( project_id=self.project_id, location=self.location, agent_engine_id=self.agent_engine_id, config=self.config, request_timeout=self.request_timeout, ) + self.log.info("Agent Engine %s was queried.", self.agent_engine_id) + return result class UpdateAgentEngineOperator(GoogleCloudBaseOperator): @@ -275,6 +284,7 @@ def hook(self) -> AgentEngineHook: ) def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Updating Agent Engine %s.", self.agent_engine_id) agent_engine = 
self.hook.update_agent_engine( project_id=self.project_id, location=self.location, @@ -283,7 +293,9 @@ def execute(self, context: Context) -> dict[str, Any]: agent_engine=self.agent_engine, config=self.config, ) - return _serialize_agent_engine(agent_engine) + result = _serialize_agent_engine(agent_engine) + self.log.info("Agent Engine %s was updated.", self.agent_engine_id) + return result class DeleteAgentEngineOperator(GoogleCloudBaseOperator): @@ -350,6 +362,7 @@ def hook(self) -> AgentEngineHook: ) def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Deleting Agent Engine %s.", self.agent_engine_id) operation = self.hook.delete_agent_engine( project_id=self.project_id, location=self.location, @@ -383,6 +396,7 @@ def execute(self, context: Context) -> dict[str, Any]: poll_interval=self.poll_interval, timeout=self.timeout, ) + self.log.info("Agent Engine %s was deleted.", self.agent_engine_id) return result def execute_complete( From 64e08f1b51fd56fcf6b2b1e341a23fb2748bf255 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Mon, 15 Jun 2026 23:39:02 -0300 Subject: [PATCH 10/32] Type Vertex AI Agent Engine hook results --- .../google/cloud/hooks/vertex_ai/agent_engine.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 5215e6e284041..58aaa489ffc72 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -22,7 +22,7 @@ import json import time from collections.abc import Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from asgiref.sync import sync_to_async from google.genai._api_client import HttpOptions @@ -35,6 +35,9 @@ 
GoogleBaseHook, ) +if TYPE_CHECKING: + from vertexai._genai import types + class AgentEngineHook(GoogleBaseHook): """Hook for Google Cloud Vertex AI Agent Engine APIs.""" @@ -72,7 +75,7 @@ def create_agent_engine( agent_engine: Any | None = None, config: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, - ) -> Any: + ) -> types.AgentEngine: """ Create an Agent Engine. @@ -92,7 +95,7 @@ def get_agent_engine( location: str, agent_engine_id: str, project_id: str = PROVIDE_PROJECT_ID, - ) -> Any: + ) -> types.AgentEngine: """ Get an Agent Engine. @@ -159,7 +162,7 @@ def update_agent_engine( agent: Any | None = None, agent_engine: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, - ) -> Any: + ) -> types.AgentEngine: """ Update an Agent Engine. @@ -183,7 +186,7 @@ def delete_agent_engine( force: bool | None = None, config: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, - ) -> Any: + ) -> types.DeleteAgentEngineOperation: """ Delete an Agent Engine. From 161833584d9b5859574ab9148706cb29a6811b17 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Tue, 16 Jun 2026 00:15:15 -0300 Subject: [PATCH 11/32] Tighten Vertex AI Agent Engine API surface --- providers/google/pyproject.toml | 1 + .../cloud/hooks/vertex_ai/agent_engine.py | 22 +++++--- .../cloud/operators/vertex_ai/agent_engine.py | 10 ---- .../hooks/vertex_ai/test_agent_engine.py | 51 +++++++++++-------- .../operators/vertex_ai/test_agent_engine.py | 31 ----------- 5 files changed, 44 insertions(+), 71 deletions(-) diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml index d42cc3d3a2802..d241551a9a240 100644 --- a/providers/google/pyproject.toml +++ b/providers/google/pyproject.toml @@ -80,6 +80,7 @@ dependencies = [ "google-api-python-client>=2.0.2", "google-auth>=2.29.0", "google-auth-httplib2>=0.0.1", + "google-genai>=2.8.0", # google-cloud-aiplatform doesn't install ray for python 3.12 (issue: 
https://github.com/googleapis/python-aiplatform/issues/5252). # Temporarily lock in ray 2.42.0 which is compatible with python 3.12 until linked issue is solved. # Remove the ray dependency as well as google-cloud-bigquery-storage once linked issue is fixed diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 58aaa489ffc72..853e79091af57 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -72,7 +72,6 @@ def create_agent_engine( self, location: str, agent: Any | None = None, - agent_engine: Any | None = None, config: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.AgentEngine: @@ -81,13 +80,12 @@ def create_agent_engine( :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent: Optional. The agent object to deploy. - :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Optional. Configuration for the Agent Engine. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) - return client.create(agent=agent, agent_engine=agent_engine, config=config) + return client.create(agent=agent, config=config) @GoogleBaseHook.fallback_to_default_project_id def get_agent_engine( @@ -131,6 +129,8 @@ def query_agent_engine( # (requires GCS) or _query (private method; triggers a Pydantic parsing bug in # google-genai 2.8.0 when the response output type is Any). Calling request() bypasses # Pydantic parsing while still letting the SDK handle URL construction and auth. 
+ # Replace with a public synchronous query API when available; tracked at + # https://github.com/apache/airflow/issues/68605 cfg = config if isinstance(config, dict) else {} body: dict[str, Any] = {"classMethod": cfg.get("class_method", "query")} if "input" in cfg: @@ -149,9 +149,17 @@ def query_agent_engine( timeout=int(request_timeout * 1000) if request_timeout is not None else None ) name = self.build_agent_engine_name(project_id, location, agent_engine_id) - response = sdk_client._api_client.request("post", f"{name}:query", body, http_options) + api_client = getattr(sdk_client, "_api_client", None) + request = getattr(api_client, "request", None) + if request is None: + raise RuntimeError( + "The Vertex AI Agent Engine SDK no longer exposes _api_client.request. " + "QueryAgentEngineOperator must be updated to use a supported synchronous query API." + ) + response = request("post", f"{name}:query", body, http_options) data = {} if not response.body else json.loads(response.body) - return data.get("output", data) + output = data.get("output") + return output if output is not None else data @GoogleBaseHook.fallback_to_default_project_id def update_agent_engine( @@ -160,7 +168,6 @@ def update_agent_engine( agent_engine_id: str, config: Any, agent: Any | None = None, - agent_engine: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.AgentEngine: """ @@ -170,13 +177,12 @@ def update_agent_engine( :param agent_engine_id: Required. The Agent Engine ID. :param config: Required. Configuration for the Agent Engine update. :param agent: Optional. The updated agent object to deploy. - :param agent_engine: Optional. Deprecated alias for ``agent``. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. 
""" client = self.get_agent_engine_client(project_id=project_id, location=location) name = self.build_agent_engine_name(project_id, location, agent_engine_id) - return client.update(name=name, agent=agent, agent_engine=agent_engine, config=config) + return client.update(name=name, agent=agent, config=config) @GoogleBaseHook.fallback_to_default_project_id def delete_agent_engine( diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 436512f352704..36d473bc5f5ee 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -58,7 +58,6 @@ class CreateAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent: Optional. The agent object to deploy. - :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Optional. Configuration for the Agent Engine. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
@@ -68,7 +67,6 @@ class CreateAgentEngineOperator(GoogleCloudBaseOperator): "project_id", "location", "agent", - "agent_engine", "config", "gcp_conn_id", "impersonation_chain", @@ -80,7 +78,6 @@ def __init__( project_id: str, location: str, agent: Any | None = None, - agent_engine: Any | None = None, config: Any | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -90,7 +87,6 @@ def __init__( self.project_id = project_id self.location = location self.agent = agent - self.agent_engine = agent_engine self.config = config self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -108,7 +104,6 @@ def execute(self, context: Context) -> dict[str, Any]: project_id=self.project_id, location=self.location, agent=self.agent, - agent_engine=self.agent_engine, config=self.config, ) result = _serialize_agent_engine(agent_engine) @@ -236,7 +231,6 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. :param agent: Optional. The updated agent object to deploy. - :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Required. Configuration for the Agent Engine update. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
@@ -247,7 +241,6 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): "location", "agent_engine_id", "agent", - "agent_engine", "config", "gcp_conn_id", "impersonation_chain", @@ -261,7 +254,6 @@ def __init__( agent_engine_id: str, config: Any, agent: Any | None = None, - agent_engine: Any | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -271,7 +263,6 @@ def __init__( self.location = location self.agent_engine_id = agent_engine_id self.agent = agent - self.agent_engine = agent_engine self.config = config self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -290,7 +281,6 @@ def execute(self, context: Context) -> dict[str, Any]: location=self.location, agent_engine_id=self.agent_engine_id, agent=self.agent, - agent_engine=self.agent_engine, config=self.config, ) result = _serialize_agent_engine(agent_engine) diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index 5598e731dca35..47e1f67b0a010 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -70,7 +70,6 @@ def test_create_agent_engine(self, mock_get_client): mock_get_client.assert_called_once_with(self.hook, project_id=GCP_PROJECT, location=GCP_LOCATION) mock_get_client.return_value.create.assert_called_once_with( agent=None, - agent_engine=None, config=CONFIG, ) assert result == mock_get_client.return_value.create.return_value @@ -121,6 +120,35 @@ def test_query_agent_engine_returns_full_response_when_output_missing(self, mock assert result == full_response + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_query_agent_engine_returns_full_response_when_output_is_none(self, mock_get_client): + 
full_response = {"output": None} + mock_get_client.return_value._api_client.request.return_value.body = json.dumps(full_response) + + result = self.hook.query_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + ) + + assert result == full_response + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_query_agent_engine_raises_when_sdk_request_helper_is_missing(self, mock_get_client): + del mock_get_client.return_value._api_client.request + + with pytest.raises( + RuntimeError, + match="The Vertex AI Agent Engine SDK no longer exposes _api_client.request", + ): + self.hook.query_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + ) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) def test_query_agent_engine_parses_json_string_input(self, mock_get_client): mock_get_client.return_value._api_client.request.return_value.body = json.dumps( @@ -181,27 +209,6 @@ def test_update_agent_engine(self, mock_get_client): mock_get_client.return_value.update.assert_called_once_with( name=AGENT_ENGINE_NAME, agent=None, - agent_engine=None, - config=CONFIG, - ) - assert result == mock_get_client.return_value.update.return_value - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) - def test_update_agent_engine_with_deprecated_agent_engine_alias(self, mock_get_client): - agent_engine = object() - - result = self.hook.update_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - agent_engine=agent_engine, - config=CONFIG, - ) - - mock_get_client.return_value.update.assert_called_once_with( - name=AGENT_ENGINE_NAME, - agent=None, - agent_engine=agent_engine, config=CONFIG, ) assert result == mock_get_client.return_value.update.return_value 
diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index c636d7d878661..b7792c3dcad1d 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -91,7 +91,6 @@ def test_execute(self, mock_hook, context): project_id=GCP_PROJECT, location=GCP_LOCATION, agent=None, - agent_engine=None, config=CONFIG, ) assert result == {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} @@ -169,36 +168,6 @@ def test_execute(self, mock_hook, context): location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, agent=None, - agent_engine=None, - config=CONFIG, - ) - assert result == {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} - - @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) - def test_execute_with_deprecated_agent_engine_alias(self, mock_hook, context): - agent_engine = object() - mock_hook.return_value.update_agent_engine.return_value = FakeAgentEngine( - {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} - ) - op = UpdateAgentEngineOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - agent_engine=agent_engine, - config=CONFIG, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) - - result = op.execute(context=context) - - mock_hook.return_value.update_agent_engine.assert_called_once_with( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - agent=None, - agent_engine=agent_engine, config=CONFIG, ) assert result == {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} From 0fe636952b3b1a8df6156fdc92df6740d8cfcf5f Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Tue, 16 
Jun 2026 06:06:09 -0300 Subject: [PATCH 12/32] Update lock file for Agent Engine dependencies --- uv.lock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/uv.lock b/uv.lock index eac0d996a7baf..d411a08586152 100644 --- a/uv.lock +++ b/uv.lock @@ -5449,6 +5449,7 @@ dependencies = [ { name = "google-cloud-videointelligence" }, { name = "google-cloud-vision" }, { name = "google-cloud-workflows" }, + { name = "google-genai" }, { name = "grpcio-gcp" }, { name = "httpx" }, { name = "immutabledict" }, @@ -5642,6 +5643,7 @@ requires-dist = [ { name = "google-cloud-videointelligence", specifier = ">=2.11.0" }, { name = "google-cloud-vision", specifier = ">=3.4.0" }, { name = "google-cloud-workflows", specifier = ">=1.10.0" }, + { name = "google-genai", specifier = ">=2.8.0" }, { name = "grpcio-gcp", specifier = ">=0.2.2" }, { name = "httpx", specifier = ">=0.25.0" }, { name = "immutabledict", specifier = ">=4.2.0" }, From 09b7cb5a79339aec5d30c1ec4bfb579a5dddc4fa Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Tue, 16 Jun 2026 07:15:50 -0300 Subject: [PATCH 13/32] Update Google provider dependency docs --- providers/google/README.rst | 1 + providers/google/docs/index.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/providers/google/README.rst b/providers/google/README.rst index 6c252d575f974..842ead044e002 100644 --- a/providers/google/README.rst +++ b/providers/google/README.rst @@ -76,6 +76,7 @@ PIP package Version required ``google-api-python-client`` ``>=2.0.2`` ``google-auth`` ``>=2.29.0`` ``google-auth-httplib2`` ``>=0.0.1`` +``google-genai`` ``>=2.8.0`` ``google-cloud-aiplatform[evaluation]`` ``>=1.155.0`` ``ray[default]`` ``>=2.42.0; python_version < "3.13"`` ``ray[default]`` ``>=2.49.0; python_version >= "3.13" and python_version < "3.14"`` diff --git a/providers/google/docs/index.rst b/providers/google/docs/index.rst index ff860c759857c..90a49e2f6edda 100644 --- a/providers/google/docs/index.rst 
+++ b/providers/google/docs/index.rst @@ -129,6 +129,7 @@ PIP package Version required ``google-api-python-client`` ``>=2.0.2`` ``google-auth`` ``>=2.29.0`` ``google-auth-httplib2`` ``>=0.0.1`` +``google-genai`` ``>=2.8.0`` ``google-cloud-aiplatform[evaluation]`` ``>=1.155.0`` ``ray[default]`` ``>=2.42.0; python_version < "3.13"`` ``ray[default]`` ``>=2.49.0; python_version >= "3.13" and python_version < "3.14"`` From 720effc71c57300b6b543f3c0c8b249ae3a3a593 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Thu, 18 Jun 2026 20:39:42 -0300 Subject: [PATCH 14/32] Wait for Agent Engine delete operations --- .../cloud/hooks/vertex_ai/agent_engine.py | 52 +++++++++++----- .../cloud/operators/vertex_ai/agent_engine.py | 24 +++++--- .../google/cloud/triggers/vertex_ai.py | 29 +++++++-- .../hooks/vertex_ai/test_agent_engine.py | 59 +++++++++++++++++++ .../operators/vertex_ai/test_agent_engine.py | 29 +++++++-- .../triggers/test_vertex_ai_agent_engine.py | 32 ++++++++-- 6 files changed, 188 insertions(+), 37 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 853e79091af57..78be2ccc3b7fc 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -24,6 +24,7 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any +import google.auth.transport.requests from asgiref.sync import sync_to_async from google.genai._api_client import HttpOptions from google.genai.errors import ClientError @@ -72,7 +73,7 @@ def create_agent_engine( self, location: str, agent: Any | None = None, - config: Any | None = None, + config: types.AgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.AgentEngine: """ @@ 
-166,7 +167,7 @@ def update_agent_engine( self, location: str, agent_engine_id: str, - config: Any, + config: types.AgentEngineConfigOrDict, agent: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.AgentEngine: @@ -190,7 +191,7 @@ def delete_agent_engine( location: str, agent_engine_id: str, force: bool | None = None, - config: Any | None = None, + config: types.DeleteAgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.DeleteAgentEngineOperation: """ @@ -222,34 +223,45 @@ def is_agent_engine_deleted(self, project_id: str, location: str, agent_engine_i raise return False - def wait_for_agent_engine_deleted( + def get_agent_engine_operation(self, location: str, operation_name: str) -> dict[str, Any]: + """Return a Vertex AI Agent Engine long-running operation.""" + url = ( + operation_name + if operation_name.startswith("http") + else f"https://{location}-aiplatform.googleapis.com/v1beta1/{operation_name}" + ) + session = google.auth.transport.requests.AuthorizedSession(self.get_credentials()) + response = session.get(url) + response.raise_for_status() + return response.json() + + def wait_for_agent_engine_operation( self, - project_id: str, location: str, - agent_engine_id: str, + operation_name: str, poll_interval: float = 30, timeout: float | None = None, ) -> None: """ - Wait until an Agent Engine no longer exists. + Wait until an Agent Engine operation completes. - :param project_id: The ID of the Google Cloud project that the service belongs to. :param location: The ID of the Google Cloud location that the service belongs to. - :param agent_engine_id: The Agent Engine ID. + :param operation_name: The Agent Engine operation name. :param poll_interval: Time, in seconds, to wait between checks. :param timeout: Optional timeout, in seconds. 
""" start_time = time.monotonic() while True: - if self.is_agent_engine_deleted( - project_id=project_id, - location=location, - agent_engine_id=agent_engine_id, - ): + operation = self.get_agent_engine_operation(location=location, operation_name=operation_name) + if operation.get("done"): + if operation.get("error"): + raise RuntimeError( + f"Agent Engine operation {operation_name} failed: {operation['error']}" + ) return if timeout is not None and time.monotonic() - start_time > timeout: - raise TimeoutError(f"Timed out waiting for Agent Engine {agent_engine_id} to be deleted") - self.log.info("Waiting for Agent Engine %s to be deleted.", agent_engine_id) + raise TimeoutError(f"Timed out waiting for Agent Engine operation {operation_name}") + self.log.info("Waiting for Agent Engine operation %s to complete.", operation_name) time.sleep(poll_interval) @@ -278,3 +290,11 @@ async def is_agent_engine_deleted(self, project_id: str, location: str, agent_en location=location, agent_engine_id=agent_engine_id, ) + + async def get_agent_engine_operation(self, location: str, operation_name: str) -> dict[str, Any]: + """Return a Vertex AI Agent Engine long-running operation.""" + sync_hook = await self.get_sync_hook() + return await sync_to_async(sync_hook.get_agent_engine_operation)( + location=location, + operation_name=operation_name, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 36d473bc5f5ee..dbfbf3b08ffd4 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -29,6 +29,8 @@ from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineDeleteTrigger if TYPE_CHECKING: + from vertexai._genai import types + from airflow.providers.common.compat.sdk import Context @@ 
-44,7 +46,7 @@ def _serialize_value(value: Any) -> Any: return value -def _serialize_agent_engine(agent_engine: Any) -> dict[str, Any]: +def _serialize_agent_engine(agent_engine: types.AgentEngine) -> dict[str, Any]: api_resource = getattr(agent_engine, "api_resource", None) if api_resource is not None: return _serialize_value(api_resource) @@ -78,7 +80,7 @@ def __init__( project_id: str, location: str, agent: Any | None = None, - config: Any | None = None, + config: types.AgentEngineConfigOrDict | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -252,7 +254,7 @@ def __init__( project_id: str, location: str, agent_engine_id: str, - config: Any, + config: types.AgentEngineConfigOrDict, agent: Any | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -322,7 +324,7 @@ def __init__( location: str, agent_engine_id: str, force: bool | None = None, - config: Any | None = None, + config: types.DeleteAgentEngineConfigOrDict | None = None, wait_for_completion: bool = True, poll_interval: float = 30, timeout: float | None = None, @@ -364,6 +366,14 @@ def execute(self, context: Context) -> dict[str, Any]: if not self.wait_for_completion: return result + operation_name = getattr(operation, "name", None) + if not operation_name: + raise RuntimeError("Delete Agent Engine operation did not include an operation name.") + + if getattr(operation, "done", False): + self.log.info("Agent Engine %s was deleted.", self.agent_engine_id) + return result + if self.deferrable: self.defer( trigger=AgentEngineDeleteTrigger( @@ -374,15 +384,15 @@ def execute(self, context: Context) -> dict[str, Any]: impersonation_chain=self.impersonation_chain, poll_interval=self.poll_interval, timeout=self.timeout, + operation_name=operation_name, ), method_name="execute_complete", kwargs={"operation": result}, ) - self.hook.wait_for_agent_engine_deleted( - 
project_id=self.project_id, + self.hook.wait_for_agent_engine_operation( location=self.location, - agent_engine_id=self.agent_engine_id, + operation_name=operation_name, poll_interval=self.poll_interval, timeout=self.timeout, ) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index 718901f1506d7..2229fa143d92c 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -141,6 +141,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, poll_interval: float = 30, timeout: float | None = None, + operation_name: str | None = None, ): super().__init__() self.project_id = project_id @@ -150,6 +151,7 @@ def __init__( self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval self.timeout = timeout + self.operation_name = operation_name def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -162,6 +164,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, "timeout": self.timeout, + "operation_name": self.operation_name, }, ) @@ -173,15 +176,33 @@ def async_hook(self) -> AgentEngineAsyncHook: ) async def run(self) -> AsyncIterator[TriggerEvent]: + if not self.operation_name: + yield TriggerEvent( + { + "status": "error", + "message": "Delete Agent Engine operation name is required.", + "agent_engine_id": self.agent_engine_id, + } + ) + return + start_time = time.monotonic() try: while True: - deleted = await self.async_hook.is_agent_engine_deleted( - project_id=self.project_id, + operation = await self.async_hook.get_agent_engine_operation( location=self.location, - agent_engine_id=self.agent_engine_id, + operation_name=self.operation_name, ) - if deleted: + if operation.get("done"): + if operation.get("error"): + 
yield TriggerEvent( + { + "status": "error", + "message": str(operation["error"]), + "agent_engine_id": self.agent_engine_id, + } + ) + return yield TriggerEvent( { "status": "success", diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index 47e1f67b0a010..0594e56892ef1 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -35,6 +35,7 @@ GCP_LOCATION = "us-central1" AGENT_ENGINE_ID = "123" AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" +OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" CONFIG = {"display_name": "test-agent-engine"} QUERY_CONFIG = {"class_method": "query", "input": {"prompt": "hello"}} @@ -230,6 +231,64 @@ def test_delete_agent_engine(self, mock_get_client): ) assert result == mock_get_client.return_value.delete.return_value + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + def test_get_agent_engine_operation(self, mock_session): + self.hook.get_credentials = mock.Mock(return_value=mock.sentinel.credentials, spec=()) + mock_session.return_value.get.return_value.json.return_value = {"name": OPERATION_NAME, "done": True} + + result = self.hook.get_agent_engine_operation( + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + ) + + mock_session.assert_called_once_with(mock.sentinel.credentials) + mock_session.return_value.get.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}" + ) + mock_session.return_value.get.return_value.raise_for_status.assert_called_once_with() + assert result == {"name": OPERATION_NAME, "done": True} + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), 
autospec=True) + def test_wait_for_agent_engine_operation_returns_when_done(self, mock_get_operation): + mock_get_operation.return_value = {"name": OPERATION_NAME, "done": True} + + self.hook.wait_for_agent_engine_operation( + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + ) + + mock_get_operation.assert_called_once_with( + self.hook, + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) + def test_wait_for_agent_engine_operation_raises_on_error(self, mock_get_operation): + mock_get_operation.return_value = {"name": OPERATION_NAME, "done": True, "error": {"message": "boom"}} + + with pytest.raises(RuntimeError, match="Agent Engine operation .* failed"): + self.hook.wait_for_agent_engine_operation( + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) + def test_wait_for_agent_engine_operation_times_out(self, mock_get_operation, mock_monotonic, mock_sleep): + mock_get_operation.return_value = {"name": OPERATION_NAME, "done": False} + mock_monotonic.side_effect = [1, 3] + + with pytest.raises(TimeoutError, match="Timed out waiting for Agent Engine operation"): + self.hook.wait_for_agent_engine_operation( + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + timeout=1, + ) + + mock_sleep.assert_not_called() + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) def test_is_agent_engine_deleted_returns_false_when_resource_exists(self, mock_get_agent_engine): assert not self.hook.is_agent_engine_deleted( diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py 
b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index b7792c3dcad1d..46e774b5d65f9 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -47,6 +47,8 @@ class FakeModel: def __init__(self, payload): self.payload = payload + for key, value in payload.items(): + setattr(self, key, value) def model_dump(self, mode="json"): return self.payload @@ -198,7 +200,7 @@ def test_execute_without_wait(self, mock_hook, context): force=True, config=CONFIG, ) - mock_hook.return_value.wait_for_agent_engine_deleted.assert_not_called() + mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() assert result == OPERATION @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) @@ -218,15 +220,33 @@ def test_execute_waits_until_deleted(self, mock_hook, context): result = op.execute(context=context) - mock_hook.return_value.wait_for_agent_engine_deleted.assert_called_once_with( - project_id=GCP_PROJECT, + mock_hook.return_value.wait_for_agent_engine_operation.assert_called_once_with( location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, + operation_name=OPERATION["name"], poll_interval=1, timeout=60, ) assert result == OPERATION + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_does_not_wait_when_delete_operation_is_done(self, mock_hook, context): + operation = {"name": "operations/delete-123", "done": True} + mock_hook.return_value.delete_agent_engine.return_value = FakeModel(operation) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + wait_for_completion=True, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() + 
assert result == operation + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineDeleteTrigger"), autospec=True) @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) def test_execute_deferrable(self, mock_hook, mock_trigger, context): @@ -255,6 +275,7 @@ def test_execute_deferrable(self, mock_hook, mock_trigger, context): impersonation_chain=IMPERSONATION_CHAIN, poll_interval=1, timeout=60, + operation_name=OPERATION["name"], ) assert exc.value.kwargs == {"operation": OPERATION} diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index c49fd34391835..6fb122c6bd984 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -29,6 +29,7 @@ GCP_CONN_ID = "test-conn" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] AGENT_ENGINE_ID = "123" +OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" @pytest.fixture @@ -41,6 +42,7 @@ def delete_trigger(): impersonation_chain=IMPERSONATION_CHAIN, poll_interval=1, timeout=60, + operation_name=OPERATION_NAME, ) @@ -56,20 +58,20 @@ def test_serialize(self, delete_trigger): "impersonation_chain": IMPERSONATION_CHAIN, "poll_interval": 1, "timeout": 60, + "operation_name": OPERATION_NAME, }, ) @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): - mock_hook.return_value.is_agent_engine_deleted.return_value = True + mock_hook.return_value.get_agent_engine_operation.return_value = {"done": True} event = await delete_trigger.run().asend(None) - mock_hook.return_value.is_agent_engine_deleted.assert_called_once_with( - project_id=GCP_PROJECT, + 
mock_hook.return_value.get_agent_engine_operation.assert_called_once_with( location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, + operation_name=OPERATION_NAME, ) assert event == TriggerEvent( { @@ -84,7 +86,7 @@ async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete_trigger): delete_trigger.timeout = -1 - mock_hook.return_value.is_agent_engine_deleted.return_value = False + mock_hook.return_value.get_agent_engine_operation.return_value = {"done": False} event = await delete_trigger.run().asend(None) @@ -100,7 +102,7 @@ async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_error_event(self, mock_hook, delete_trigger): - mock_hook.return_value.is_agent_engine_deleted.side_effect = RuntimeError("boom") + mock_hook.return_value.get_agent_engine_operation.side_effect = RuntimeError("boom") event = await delete_trigger.run().asend(None) @@ -111,3 +113,21 @@ async def test_run_loop_return_error_event(self, mock_hook, delete_trigger): "agent_engine_id": AGENT_ENGINE_ID, } ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_operation_error_event(self, mock_hook, delete_trigger): + mock_hook.return_value.get_agent_engine_operation.return_value = { + "done": True, + "error": {"message": "boom"}, + } + + event = await delete_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": "{'message': 'boom'}", + "agent_engine_id": AGENT_ENGINE_ID, + } + ) From a0b0054c1e7ca3cbfeaad11c18a4834f770deeb9 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante 
<62363051+AlejandroMorgante@users.noreply.github.com> Date: Thu, 18 Jun 2026 20:51:28 -0300 Subject: [PATCH 15/32] Remove Agent Engine delete resource polling --- .../cloud/hooks/vertex_ai/agent_engine.py | 24 --------- .../hooks/vertex_ai/test_agent_engine.py | 54 ------------------- 2 files changed, 78 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 78be2ccc3b7fc..a4bd583aec4c9 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -27,7 +27,6 @@ import google.auth.transport.requests from asgiref.sync import sync_to_async from google.genai._api_client import HttpOptions -from google.genai.errors import ClientError from vertexai import Client from airflow.providers.google.common.hooks.base_google import ( @@ -209,20 +208,6 @@ def delete_agent_engine( name = self.build_agent_engine_name(project_id, location, agent_engine_id) return client.delete(name=name, force=force, config=config) - def is_agent_engine_deleted(self, project_id: str, location: str, agent_engine_id: str) -> bool: - """Return whether an Agent Engine no longer exists.""" - try: - self.get_agent_engine( - project_id=project_id, - location=location, - agent_engine_id=agent_engine_id, - ) - except ClientError as err: - if getattr(err, "code", None) == 404: - return True - raise - return False - def get_agent_engine_operation(self, location: str, operation_name: str) -> dict[str, Any]: """Return a Vertex AI Agent Engine long-running operation.""" url = ( @@ -282,15 +267,6 @@ def __init__( **kwargs, ) - async def is_agent_engine_deleted(self, project_id: str, location: str, agent_engine_id: str) -> bool: - """Return whether an Agent Engine no longer exists.""" - sync_hook = await self.get_sync_hook() - return await 
sync_to_async(sync_hook.is_agent_engine_deleted)( - project_id=project_id, - location=location, - agent_engine_id=agent_engine_id, - ) - async def get_agent_engine_operation(self, location: str, operation_name: str) -> dict[str, Any]: """Return a Vertex AI Agent Engine long-running operation.""" sync_hook = await self.get_sync_hook() diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index 0594e56892ef1..cdb8a111e2108 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -21,7 +21,6 @@ from unittest import mock import pytest -from google.genai.errors import ClientError from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook @@ -288,56 +287,3 @@ def test_wait_for_agent_engine_operation_times_out(self, mock_get_operation, moc ) mock_sleep.assert_not_called() - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) - def test_is_agent_engine_deleted_returns_false_when_resource_exists(self, mock_get_agent_engine): - assert not self.hook.is_agent_engine_deleted( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - mock_get_agent_engine.assert_called_once_with( - self.hook, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) - def test_is_agent_engine_deleted_returns_true_on_404(self, mock_get_agent_engine): - mock_get_agent_engine.side_effect = ClientError(code=404, response_json={"error": "not found"}) - - assert self.hook.is_agent_engine_deleted( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - 
@mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) - def test_is_agent_engine_deleted_reraises_non_404(self, mock_get_agent_engine): - mock_get_agent_engine.side_effect = ClientError(code=500, response_json={"error": "server error"}) - - with pytest.raises(ClientError) as err: - self.hook.is_agent_engine_deleted( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - assert err.value.code == 500 - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine"), autospec=True) - def test_is_agent_engine_deleted_reraises_non_404_with_404_in_message(self, mock_get_agent_engine): - mock_get_agent_engine.side_effect = ClientError( - code=500, - response_json={"error": "server error for resource 404"}, - ) - - with pytest.raises(ClientError) as err: - self.hook.is_agent_engine_deleted( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - assert err.value.code == 500 From 44b5f2ce9ee850dccd9194fb32363363c71bee45 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 19 Jun 2026 07:03:48 -0300 Subject: [PATCH 16/32] Use public Agent Engine query job API --- .../cloud/hooks/vertex_ai/agent_engine.py | 48 +------- .../cloud/operators/vertex_ai/agent_engine.py | 20 ++-- .../example_vertex_ai_agent_engine.py | 21 +++- .../hooks/vertex_ai/test_agent_engine.py | 106 +----------------- .../operators/vertex_ai/test_agent_engine.py | 12 +- 5 files changed, 45 insertions(+), 162 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index a4bd583aec4c9..8eb5d47c663d0 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ 
-19,14 +19,12 @@ from __future__ import annotations -import json import time from collections.abc import Sequence from typing import TYPE_CHECKING, Any import google.auth.transport.requests from asgiref.sync import sync_to_async -from google.genai._api_client import HttpOptions from vertexai import Client from airflow.providers.google.common.hooks.base_google import ( @@ -111,55 +109,21 @@ def query_agent_engine( self, location: str, agent_engine_id: str, - config: Any | None = None, - request_timeout: float | None = None, + config: types.RunQueryJobAgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, - ) -> Any: + ) -> types.RunQueryJobResult: """ - Query an Agent Engine. + Run a query job on an Agent Engine. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. - :param config: Optional. Configuration for the query request (``class_method``, ``input``). - :param request_timeout: Optional. Timeout in seconds for the HTTP request. Defaults to no timeout. + :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ - # Use the SDK's _api_client.request() directly rather than the SDK's run_query_job - # (requires GCS) or _query (private method; triggers a Pydantic parsing bug in - # google-genai 2.8.0 when the response output type is Any). Calling request() bypasses - # Pydantic parsing while still letting the SDK handle URL construction and auth. 
- # Replace with a public synchronous query API when available; tracked at - # https://github.com/apache/airflow/issues/68605 - cfg = config if isinstance(config, dict) else {} - body: dict[str, Any] = {"classMethod": cfg.get("class_method", "query")} - if "input" in cfg: - input_val = cfg["input"] - if isinstance(input_val, str): - try: - input_val = json.loads(input_val) - except json.JSONDecodeError as err: - raise ValueError("Agent Engine query input must be valid JSON.") from err - if not isinstance(input_val, dict): - raise ValueError("Agent Engine query input must be a JSON object.") - body["input"] = input_val - - sdk_client = self.get_agent_engine_client(project_id=project_id, location=location) - http_options = HttpOptions( - timeout=int(request_timeout * 1000) if request_timeout is not None else None - ) + client = self.get_agent_engine_client(project_id=project_id, location=location) name = self.build_agent_engine_name(project_id, location, agent_engine_id) - api_client = getattr(sdk_client, "_api_client", None) - request = getattr(api_client, "request", None) - if request is None: - raise RuntimeError( - "The Vertex AI Agent Engine SDK no longer exposes _api_client.request. " - "QueryAgentEngineOperator must be updated to use a supported synchronous query API." 
- ) - response = request("post", f"{name}:query", body, http_options) - data = {} if not response.body else json.loads(response.body) - output = data.get("output") - return output if output is not None else data + return client.run_query_job(name=name, config=config) @GoogleBaseHook.fallback_to_default_project_id def update_agent_engine( diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index dbfbf3b08ffd4..5f8ea6c2697bb 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -164,13 +164,12 @@ def execute(self, context: Context) -> dict[str, Any]: class QueryAgentEngineOperator(GoogleCloudBaseOperator): """ - Query a Vertex AI Agent Engine. + Run a query job on a Vertex AI Agent Engine. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. - :param config: Optional. Configuration for the query request (``class_method``, ``input``). - :param request_timeout: Optional. Timeout in seconds for the HTTP request. Defaults to no timeout. + :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
""" @@ -190,8 +189,7 @@ def __init__( project_id: str, location: str, agent_engine_id: str, - config: Any | None = None, - request_timeout: float | None = None, + config: types.RunQueryJobAgentEngineConfigOrDict | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -201,7 +199,6 @@ def __init__( self.location = location self.agent_engine_id = agent_engine_id self.config = config - self.request_timeout = request_timeout self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -212,17 +209,16 @@ def hook(self) -> AgentEngineHook: impersonation_chain=self.impersonation_chain, ) - def execute(self, context: Context) -> Any: - self.log.info("Querying Agent Engine %s.", self.agent_engine_id) - result = self.hook.query_agent_engine( + def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Running query job on Agent Engine %s.", self.agent_engine_id) + query_job = self.hook.query_agent_engine( project_id=self.project_id, location=self.location, agent_engine_id=self.agent_engine_id, config=self.config, - request_timeout=self.request_timeout, ) - self.log.info("Agent Engine %s was queried.", self.agent_engine_id) - return result + self.log.info("Query job was started on Agent Engine %s.", self.agent_engine_id) + return _serialize_value(query_job) class UpdateAgentEngineOperator(GoogleCloudBaseOperator): diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index f2d33d27973e5..5968bbbe57c02 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -25,6 +25,7 @@ import os from datetime import datetime +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, 
GCSDeleteBucketOperator from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( CreateAgentEngineOperator, DeleteAgentEngineOperator, @@ -50,11 +51,12 @@ ) AGENT_ENGINE_ID = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'].split('/')[-1] }}" +BUCKET_NAME = f"bucket-{DAG_ID}-{ENV_ID}".replace("_", "-") DISPLAY_NAME = f"airflow-agent-engine-{ENV_ID}" QUERY_CONFIG = { - "class_method": "query", - "input": {"prompt": "Respond with a short acknowledgement."}, + "query": "Respond with a short acknowledgement.", + "output_gcs_uri": f"gs://{BUCKET_NAME}/query-output/", } with DAG( @@ -96,6 +98,12 @@ ) # [END how_to_cloud_vertex_ai_get_agent_engine_operator] + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=BUCKET_NAME, + project_id=PROJECT_ID, + ) + # [START how_to_cloud_vertex_ai_query_agent_engine_operator] query_agent_engine = QueryAgentEngineOperator( task_id="query_agent_engine", @@ -131,12 +139,21 @@ ) # [END how_to_cloud_vertex_ai_delete_agent_engine_operator] + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=BUCKET_NAME, + force=True, + trigger_rule=TriggerRule.ALL_DONE, + ) + ( create_agent_engine >> get_agent_engine + >> create_bucket >> query_agent_engine >> update_agent_engine >> delete_agent_engine + >> delete_bucket ) # ### Everything below this line is not part of example ### diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index cdb8a111e2108..f59f9e7cf7ed1 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import json from unittest import mock import pytest @@ -36,7 +35,7 @@ AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" CONFIG = {"display_name": "test-agent-engine"} -QUERY_CONFIG = {"class_method": "query", "input": {"prompt": "hello"}} +QUERY_CONFIG = {"query": "hello", "output_gcs_uri": "gs://test-bucket/query-output/"} class TestAgentEngineHookWithDefaultProjectId: @@ -87,10 +86,6 @@ def test_get_agent_engine(self, mock_get_client): @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) def test_query_agent_engine(self, mock_get_client): - mock_get_client.return_value._api_client.request.return_value.body = json.dumps( - {"output": {"answer": "hello"}} - ) - result = self.hook.query_agent_engine( project_id=GCP_PROJECT, location=GCP_LOCATION, @@ -98,104 +93,11 @@ def test_query_agent_engine(self, mock_get_client): config=QUERY_CONFIG, ) - mock_get_client.return_value._api_client.request.assert_called_once_with( - "post", - f"{AGENT_ENGINE_NAME}:query", - {"classMethod": "query", "input": {"prompt": "hello"}}, - mock.ANY, - ) - assert result == {"answer": "hello"} - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) - def test_query_agent_engine_returns_full_response_when_output_missing(self, mock_get_client): - full_response = {"someOtherField": "value"} - mock_get_client.return_value._api_client.request.return_value.body = json.dumps(full_response) - - result = self.hook.query_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - config=QUERY_CONFIG, - ) - - assert result == full_response - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) - def test_query_agent_engine_returns_full_response_when_output_is_none(self, 
mock_get_client): - full_response = {"output": None} - mock_get_client.return_value._api_client.request.return_value.body = json.dumps(full_response) - - result = self.hook.query_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, + mock_get_client.return_value.run_query_job.assert_called_once_with( + name=AGENT_ENGINE_NAME, config=QUERY_CONFIG, ) - - assert result == full_response - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) - def test_query_agent_engine_raises_when_sdk_request_helper_is_missing(self, mock_get_client): - del mock_get_client.return_value._api_client.request - - with pytest.raises( - RuntimeError, - match="The Vertex AI Agent Engine SDK no longer exposes _api_client.request", - ): - self.hook.query_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - config=QUERY_CONFIG, - ) - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) - def test_query_agent_engine_parses_json_string_input(self, mock_get_client): - mock_get_client.return_value._api_client.request.return_value.body = json.dumps( - {"output": {"answer": "hello"}} - ) - - result = self.hook.query_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - config={"class_method": "query", "input": '{"prompt": "hello"}'}, - ) - - mock_get_client.return_value._api_client.request.assert_called_once_with( - "post", - f"{AGENT_ENGINE_NAME}:query", - {"classMethod": "query", "input": {"prompt": "hello"}}, - mock.ANY, - ) - assert result == {"answer": "hello"} - - def test_query_agent_engine_rejects_invalid_json_string_input(self): - with pytest.raises(ValueError, match="Agent Engine query input must be valid JSON."): - self.hook.query_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - config={"class_method": 
"query", "input": "not valid json"}, - ) - - @pytest.mark.parametrize( - "input_value", - [ - '"test string"', - '["prompt", "hello"]', - 1, - ["prompt", "hello"], - ], - ) - def test_query_agent_engine_rejects_non_object_input(self, input_value): - with pytest.raises(ValueError, match="Agent Engine query input must be a JSON object."): - self.hook.query_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - config={"class_method": "query", "input": input_value}, - ) + assert result == mock_get_client.return_value.run_query_job.return_value @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) def test_update_agent_engine(self, mock_get_client): diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index 46e774b5d65f9..1696498e3f85d 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -40,7 +40,7 @@ AGENT_ENGINE_ID = "123" AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" CONFIG = {"display_name": "test-agent-engine"} -QUERY_CONFIG = {"class_method": "query", "input": {"prompt": "hello"}} +QUERY_CONFIG = {"query": "hello", "output_gcs_uri": "gs://test-bucket/query-output/"} OPERATION = {"name": "operations/delete-123", "done": False} @@ -124,7 +124,12 @@ def test_execute(self, mock_hook, context): class TestQueryAgentEngineOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) def test_execute(self, mock_hook, context): - mock_hook.return_value.query_agent_engine.return_value = {"output": "hello"} + result_payload = { + "job_name": "operations/query-123", + "input_gcs_uri": "gs://test-bucket/query-output/input.json", + "output_gcs_uri": 
"gs://test-bucket/query-output/output.json", + } + mock_hook.return_value.query_agent_engine.return_value = FakeModel(result_payload) op = QueryAgentEngineOperator( task_id=TASK_ID, project_id=GCP_PROJECT, @@ -142,9 +147,8 @@ def test_execute(self, mock_hook, context): location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, - request_timeout=None, ) - assert result == {"output": "hello"} + assert result == result_payload class TestUpdateAgentEngineOperator: From 4dce0a06189cd785d530ed36a0d0a6085d7117a2 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 19 Jun 2026 07:29:48 -0300 Subject: [PATCH 17/32] Document Agent Engine query job config --- providers/google/docs/operators/cloud/vertex_ai.rst | 6 +++++- .../providers/google/cloud/hooks/vertex_ai/agent_engine.py | 4 ++-- .../google/cloud/operators/vertex_ai/agent_engine.py | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index 2a82edd9cd63b..58bf393300c96 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -47,8 +47,12 @@ To get an Agent Engine you can use :start-after: [START how_to_cloud_vertex_ai_get_agent_engine_operator] :end-before: [END how_to_cloud_vertex_ai_get_agent_engine_operator] -To query an Agent Engine you can use +To run a query job on an Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.QueryAgentEngineOperator`. +The operator uses the public ``run_query_job`` SDK method. The ``config`` parameter +is required and must include ``query`` and ``output_gcs_uri``. The SDK writes query +input and output through Google Cloud Storage and the operator returns the serialized +query job result metadata, including the job name and GCS URIs. .. 
exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py :language: python diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 8eb5d47c663d0..5cf73034d2c00 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -109,7 +109,7 @@ def query_agent_engine( self, location: str, agent_engine_id: str, - config: types.RunQueryJobAgentEngineConfigOrDict | None = None, + config: types.RunQueryJobAgentEngineConfigOrDict, project_id: str = PROVIDE_PROJECT_ID, ) -> types.RunQueryJobResult: """ @@ -117,7 +117,7 @@ def query_agent_engine( :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. - :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). + :param config: Required. Configuration for the query job (``query``, ``output_gcs_uri``). :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 5f8ea6c2697bb..5bd34c819ee21 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -169,7 +169,7 @@ class QueryAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. 
:param agent_engine_id: Required. The Agent Engine ID. - :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). + :param config: Required. Configuration for the query job (``query``, ``output_gcs_uri``). :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. """ @@ -189,7 +189,7 @@ def __init__( project_id: str, location: str, agent_engine_id: str, - config: types.RunQueryJobAgentEngineConfigOrDict | None = None, + config: types.RunQueryJobAgentEngineConfigOrDict, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, From f5e3278fd2f7f95d0943ecc9790751c56926d44d Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 19 Jun 2026 08:02:31 -0300 Subject: [PATCH 18/32] Add CheckQueryAgentEngineOperator and fix query operator Introduces CheckQueryAgentEngineOperator (with deferrable support via AgentEngineQueryJobTrigger) to poll the status of a batch query job started by QueryAgentEngineOperator. Switches QueryAgentEngineOperator from a custom REST call to the SDK's public run_query_job() method, and removes the now-unused internal polling helpers. 
--- .../google/docs/operators/cloud/vertex_ai.rst | 16 ++- .../cloud/hooks/vertex_ai/agent_engine.py | 75 +++++++++++ .../cloud/operators/vertex_ai/agent_engine.py | 101 +++++++++++++- .../google/cloud/triggers/vertex_ai.py | 111 ++++++++++++++++ .../example_vertex_ai_agent_engine.py | 14 ++ .../hooks/vertex_ai/test_agent_engine.py | 69 ++++++++++ .../operators/vertex_ai/test_agent_engine.py | 118 +++++++++++++++++ .../triggers/test_vertex_ai_agent_engine.py | 124 +++++++++++++++++- 8 files changed, 625 insertions(+), 3 deletions(-) diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index 58bf393300c96..510eb82e4a348 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -52,7 +52,8 @@ To run a query job on an Agent Engine you can use The operator uses the public ``run_query_job`` SDK method. The ``config`` parameter is required and must include ``query`` and ``output_gcs_uri``. The SDK writes query input and output through Google Cloud Storage and the operator returns the serialized -query job result metadata, including the job name and GCS URIs. +query job result metadata, including the job name and GCS URIs. This operator starts +the query job and does not wait for the query job output. .. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py :language: python @@ -60,6 +61,19 @@ query job result metadata, including the job name and GCS URIs. :start-after: [START how_to_cloud_vertex_ai_query_agent_engine_operator] :end-before: [END how_to_cloud_vertex_ai_query_agent_engine_operator] +To wait for a query job and retrieve its output you can use +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.CheckQueryAgentEngineOperator`. +The operator uses the public ``check_query_job`` SDK method. The ``config`` parameter +is optional. 
Set ``retrieve_result`` to ``True`` in ``config`` to return the query job +result from Google Cloud Storage. The same operation can be performed in the +deferrable mode. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_check_query_agent_engine_operator] + :end-before: [END how_to_cloud_vertex_ai_check_query_agent_engine_operator] + To update an Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.UpdateAgentEngineOperator`. diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 5cf73034d2c00..3717a7a0804ae 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -125,6 +125,65 @@ def query_agent_engine( name = self.build_agent_engine_name(project_id, location, agent_engine_id) return client.run_query_job(name=name, config=config) + @GoogleBaseHook.fallback_to_default_project_id + def check_query_agent_engine_job( + self, + location: str, + operation_name: str, + config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.CheckQueryJobResult: + """ + Check a query job on an Agent Engine. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param operation_name: Required. The query job operation name. + :param config: Optional. Configuration for checking the query job. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. 
+ """ + client = self.get_agent_engine_client(project_id=project_id, location=location) + return client.check_query_job(name=operation_name, config=config) + + @GoogleBaseHook.fallback_to_default_project_id + def wait_for_query_agent_engine_job( + self, + location: str, + operation_name: str, + config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + poll_interval: float = 30, + timeout: float | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.CheckQueryJobResult: + """ + Wait until an Agent Engine query job completes. + + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param operation_name: Required. The query job operation name. + :param config: Optional. Configuration for checking the query job. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. 
+ """ + start_time = time.monotonic() + while True: + query_job = self.check_query_agent_engine_job( + project_id=project_id, + location=location, + operation_name=operation_name, + config=config, + ) + status = getattr(query_job, "status", None) + if status == "SUCCESS": + return query_job + if status == "FAILED": + raise RuntimeError(f"Agent Engine query job {operation_name} failed.") + if timeout is not None and time.monotonic() - start_time > timeout: + raise TimeoutError(f"Timed out waiting for Agent Engine query job {operation_name}") + self.log.info("Waiting for Agent Engine query job %s to complete.", operation_name) + time.sleep(poll_interval) + @GoogleBaseHook.fallback_to_default_project_id def update_agent_engine( self, @@ -238,3 +297,19 @@ async def get_agent_engine_operation(self, location: str, operation_name: str) - location=location, operation_name=operation_name, ) + + async def check_query_agent_engine_job( + self, + location: str, + operation_name: str, + config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + ) -> types.CheckQueryJobResult: + """Check a query job on an Agent Engine.""" + sync_hook = await self.get_sync_hook() + return await sync_to_async(sync_hook.check_query_agent_engine_job)( + project_id=project_id, + location=location, + operation_name=operation_name, + config=config, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 5bd34c819ee21..885ff2c11378d 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -26,7 +26,10 @@ from airflow.providers.common.compat.sdk import conf from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook from 
airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineDeleteTrigger +from airflow.providers.google.cloud.triggers.vertex_ai import ( + AgentEngineDeleteTrigger, + AgentEngineQueryJobTrigger, +) if TYPE_CHECKING: from vertexai._genai import types @@ -221,6 +224,102 @@ def execute(self, context: Context) -> dict[str, Any]: return _serialize_value(query_job) +class CheckQueryAgentEngineOperator(GoogleCloudBaseOperator): + """ + Check a query job on a Vertex AI Agent Engine. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param operation_name: Required. The query job operation name (e.g. from the ``job_name`` field of the result of ``QueryAgentEngineOperator``). + :param config: Optional. Configuration for checking the query job. + :param poll_interval: Time, in seconds, to wait between checks. + :param timeout: Optional timeout, in seconds. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term credentials. + :param deferrable: Run operator in the deferrable mode. 
+ """ + + template_fields = ( + "project_id", + "location", + "operation_name", + "config", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str, + location: str, + operation_name: str, + config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + poll_interval: float = 30, + timeout: float | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.operation_name = operation_name + self.config = config + self.poll_interval = poll_interval + self.timeout = timeout + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.deferrable = deferrable + + @cached_property + def hook(self) -> AgentEngineHook: + return AgentEngineHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def execute(self, context: Context) -> dict[str, Any]: + self.log.info("Checking Agent Engine query job %s.", self.operation_name) + if self.deferrable: + self.defer( + trigger=AgentEngineQueryJobTrigger( + project_id=self.project_id, + location=self.location, + operation_name=self.operation_name, + config=_serialize_value(self.config), + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + poll_interval=self.poll_interval, + timeout=self.timeout, + ), + method_name="execute_complete", + ) + + query_job = self.hook.wait_for_query_agent_engine_job( + project_id=self.project_id, + location=self.location, + operation_name=self.operation_name, + config=self.config, + poll_interval=self.poll_interval, + timeout=self.timeout, + ) + result = _serialize_value(query_job) + self.log.info("Agent Engine query job %s completed.", self.operation_name) + return result + + def execute_complete(self, 
context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: + if event is None: + raise RuntimeError("No event received in trigger callback") + if event["status"] == "success": + self.log.info("Agent Engine query job completed.") + return event["query_job"] + if event["status"] == "timeout": + raise TimeoutError(event["message"]) + raise RuntimeError(event["message"]) + + class UpdateAgentEngineOperator(GoogleCloudBaseOperator): """ Update a Vertex AI Agent Engine. diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index 2229fa143d92c..5b95cfe3eca34 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -44,6 +44,18 @@ from proto import Message +def _serialize_value(value: Any) -> Any: + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {key: _serialize_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_serialize_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_serialize_value(item) for item in value) + return value + + class BaseVertexAIJobTrigger(BaseTrigger): """ Base class for Vertex AI job triggers. 
@@ -237,6 +249,105 @@ async def run(self) -> AsyncIterator[TriggerEvent]: ) +class AgentEngineQueryJobTrigger(BaseTrigger): + """Trigger that waits until a Vertex AI Agent Engine query job completes.""" + + def __init__( + self, + project_id: str, + location: str, + operation_name: str, + config: dict[str, Any] | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + poll_interval: float = 30, + timeout: float | None = None, + ): + super().__init__() + self.project_id = project_id + self.location = location + self.operation_name = operation_name + self.config = config + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.poll_interval = poll_interval + self.timeout = timeout + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger", + { + "project_id": self.project_id, + "location": self.location, + "operation_name": self.operation_name, + "config": _serialize_value(self.config), + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "poll_interval": self.poll_interval, + "timeout": self.timeout, + }, + ) + + @cached_property + def async_hook(self) -> AgentEngineAsyncHook: + return AgentEngineAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + start_time = time.monotonic() + try: + while True: + query_job = await self.async_hook.check_query_agent_engine_job( + project_id=self.project_id, + location=self.location, + operation_name=self.operation_name, + config=self.config, + ) + status = getattr(query_job, "status", None) + if status == "SUCCESS": + yield TriggerEvent( + { + "status": "success", + "message": "Agent Engine query job completed", + "query_job": _serialize_value(query_job), + } + ) + return + if status == "FAILED": + yield TriggerEvent( + 
{ + "status": "error", + "message": f"Agent Engine query job {self.operation_name} failed.", + "query_job": _serialize_value(query_job), + } + ) + return + + if self.timeout is not None and time.monotonic() - start_time >= self.timeout: + yield TriggerEvent( + { + "status": "timeout", + "message": f"Timed out waiting for Agent Engine query job {self.operation_name}", + "operation_name": self.operation_name, + } + ) + return + + self.log.info("Waiting for Agent Engine query job %s to complete.", self.operation_name) + await asyncio.sleep(self.poll_interval) + except Exception as err: + self.log.exception("Exception occurred while waiting for Agent Engine query job.") + yield TriggerEvent( + { + "status": "error", + "message": str(err), + "operation_name": self.operation_name, + } + ) + + class CreateHyperparameterTuningJobTrigger(BaseVertexAIJobTrigger): """CreateHyperparameterTuningJobTrigger run on the trigger worker to perform create operation.""" diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index 5968bbbe57c02..ae4593cab91c0 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -27,6 +27,7 @@ from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( + CheckQueryAgentEngineOperator, CreateAgentEngineOperator, DeleteAgentEngineOperator, GetAgentEngineOperator, @@ -51,6 +52,7 @@ ) AGENT_ENGINE_ID = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'].split('/')[-1] }}" +QUERY_OPERATION_NAME = "{{ task_instance.xcom_pull(task_ids='query_agent_engine')['job_name'] }}" BUCKET_NAME = f"bucket-{DAG_ID}-{ENV_ID}".replace("_", "-") DISPLAY_NAME = 
f"airflow-agent-engine-{ENV_ID}" @@ -114,6 +116,17 @@ ) # [END how_to_cloud_vertex_ai_query_agent_engine_operator] + # [START how_to_cloud_vertex_ai_check_query_agent_engine_operator] + check_query_agent_engine = CheckQueryAgentEngineOperator( + task_id="check_query_agent_engine", + project_id=PROJECT_ID, + location=LOCATION, + operation_name=QUERY_OPERATION_NAME, + config={"retrieve_result": True}, + deferrable=True, + ) + # [END how_to_cloud_vertex_ai_check_query_agent_engine_operator] + # [START how_to_cloud_vertex_ai_update_agent_engine_operator] update_agent_engine = UpdateAgentEngineOperator( task_id="update_agent_engine", @@ -151,6 +164,7 @@ >> get_agent_engine >> create_bucket >> query_agent_engine + >> check_query_agent_engine >> update_agent_engine >> delete_agent_engine >> delete_bucket diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index f59f9e7cf7ed1..b3fd3e20676fc 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -34,8 +34,10 @@ AGENT_ENGINE_ID = "123" AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" +QUERY_OPERATION_NAME = "projects/test-project/locations/us-central1/operations/query-123" CONFIG = {"display_name": "test-agent-engine"} QUERY_CONFIG = {"query": "hello", "output_gcs_uri": "gs://test-bucket/query-output/"} +CHECK_QUERY_CONFIG = {"retrieve_result": True} class TestAgentEngineHookWithDefaultProjectId: @@ -99,6 +101,73 @@ def test_query_agent_engine(self, mock_get_client): ) assert result == mock_get_client.return_value.run_query_job.return_value + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def 
test_check_query_agent_engine_job(self, mock_get_client): + result = self.hook.check_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + + mock_get_client.return_value.check_query_job.assert_called_once_with( + name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + assert result == mock_get_client.return_value.check_query_job.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_returns_when_successful(self, mock_check_query_job): + mock_check_query_job.return_value.status = "SUCCESS" + + result = self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + + mock_check_query_job.assert_called_once_with( + self.hook, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + assert result == mock_check_query_job.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_raises_on_failed_status(self, mock_check_query_job): + mock_check_query_job.return_value.status = "FAILED" + + with pytest.raises(RuntimeError, match="Agent Engine query job .* failed"): + self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_times_out( + self, mock_check_query_job, mock_monotonic, mock_sleep + ): + 
mock_check_query_job.return_value.status = "RUNNING" + mock_monotonic.side_effect = [1, 3] + + with pytest.raises(TimeoutError, match="Timed out waiting for Agent Engine query job"): + self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + timeout=1, + ) + + mock_sleep.assert_not_called() + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) def test_update_agent_engine(self, mock_get_client): result = self.hook.update_agent_engine( diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index 1696498e3f85d..9de7f6da13099 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -23,6 +23,7 @@ from airflow.providers.common.compat.sdk import TaskDeferred from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( + CheckQueryAgentEngineOperator, CreateAgentEngineOperator, DeleteAgentEngineOperator, GetAgentEngineOperator, @@ -41,7 +42,9 @@ AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" CONFIG = {"display_name": "test-agent-engine"} QUERY_CONFIG = {"query": "hello", "output_gcs_uri": "gs://test-bucket/query-output/"} +CHECK_QUERY_CONFIG = {"retrieve_result": True} OPERATION = {"name": "operations/delete-123", "done": False} +QUERY_OPERATION_NAME = "operations/query-123" class FakeModel: @@ -151,6 +154,121 @@ def test_execute(self, mock_hook, context): assert result == result_payload +class TestCheckQueryAgentEngineOperator: + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute(self, mock_hook, context): + result_payload = { + "operation_name": QUERY_OPERATION_NAME, + 
"output_gcs_uri": "gs://test-bucket/query-output/output.json", + "status": "SUCCESS", + "result": "done", + } + mock_hook.return_value.wait_for_query_agent_engine_job.return_value = FakeModel(result_payload) + op = CheckQueryAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + poll_interval=1, + timeout=60, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.wait_for_query_agent_engine_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + poll_interval=1, + timeout=60, + ) + assert result == result_payload + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineQueryJobTrigger"), autospec=True) + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_deferrable(self, mock_hook, mock_trigger, context): + op = CheckQueryAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + poll_interval=1, + timeout=60, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + deferrable=True, + ) + + with pytest.raises(TaskDeferred): + op.execute(context=context) + + mock_hook.return_value.wait_for_query_agent_engine_job.assert_not_called() + mock_trigger.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + poll_interval=1, + timeout=60, + ) + + def test_execute_complete_success(self, context): + op = CheckQueryAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + ) + query_job = {"operation_name": 
QUERY_OPERATION_NAME, "status": "SUCCESS"} + + result = op.execute_complete( + context=context, + event={"status": "success", "message": "done", "query_job": query_job}, + ) + + assert result == query_job + + def test_execute_complete_error(self, context): + op = CheckQueryAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + ) + + with pytest.raises(RuntimeError, match="boom"): + op.execute_complete(context=context, event={"status": "error", "message": "boom"}) + + def test_execute_complete_timeout(self, context): + op = CheckQueryAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + ) + + with pytest.raises(TimeoutError, match="timed out"): + op.execute_complete(context=context, event={"status": "timeout", "message": "timed out"}) + + def test_execute_complete_without_event(self, context): + op = CheckQueryAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + ) + + with pytest.raises(RuntimeError, match="No event received in trigger callback"): + op.execute_complete(context=context) + + class TestUpdateAgentEngineOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) def test_execute(self, mock_hook, context): diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index 6fb122c6bd984..e62fedf0e5bae 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -21,7 +21,10 @@ import pytest -from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineDeleteTrigger +from airflow.providers.google.cloud.triggers.vertex_ai import ( + AgentEngineDeleteTrigger, + 
AgentEngineQueryJobTrigger, +) from airflow.triggers.base import TriggerEvent GCP_PROJECT = "test-project" @@ -30,6 +33,18 @@ IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] AGENT_ENGINE_ID = "123" OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" +QUERY_OPERATION_NAME = "projects/test-project/locations/us-central1/operations/query-123" +CHECK_QUERY_CONFIG = {"retrieve_result": True} + + +class FakeModel: + def __init__(self, payload): + self.payload = payload + for key, value in payload.items(): + setattr(self, key, value) + + def model_dump(self, mode="json"): + return self.payload @pytest.fixture @@ -46,6 +61,20 @@ def delete_trigger(): ) +@pytest.fixture +def query_job_trigger(): + return AgentEngineQueryJobTrigger( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + poll_interval=1, + timeout=60, + ) + + class TestAgentEngineDeleteTrigger: def test_serialize(self, delete_trigger): assert delete_trigger.serialize() == ( @@ -131,3 +160,96 @@ async def test_run_loop_return_operation_error_event(self, mock_hook, delete_tri "agent_engine_id": AGENT_ENGINE_ID, } ) + + +class TestAgentEngineQueryJobTrigger: + def test_serialize(self, query_job_trigger): + assert query_job_trigger.serialize() == ( + "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineQueryJobTrigger", + { + "project_id": GCP_PROJECT, + "location": GCP_LOCATION, + "operation_name": QUERY_OPERATION_NAME, + "config": CHECK_QUERY_CONFIG, + "gcp_conn_id": GCP_CONN_ID, + "impersonation_chain": IMPERSONATION_CHAIN, + "poll_interval": 1, + "timeout": 60, + }, + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_success_event(self, mock_hook, query_job_trigger): + query_job = { + "operation_name": 
QUERY_OPERATION_NAME, + "output_gcs_uri": "gs://test-bucket/query-output/output.json", + "status": "SUCCESS", + "result": "done", + } + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel(query_job) + + event = await query_job_trigger.run().asend(None) + + mock_hook.return_value.check_query_agent_engine_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + assert event == TriggerEvent( + { + "status": "success", + "message": "Agent Engine query job completed", + "query_job": query_job, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_failed_event(self, mock_hook, query_job_trigger): + query_job = {"operation_name": QUERY_OPERATION_NAME, "status": "FAILED"} + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel(query_job) + + event = await query_job_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": f"Agent Engine query job {QUERY_OPERATION_NAME} failed.", + "query_job": query_job, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, query_job_trigger): + query_job_trigger.timeout = -1 + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel({"status": "RUNNING"}) + + event = await query_job_trigger.run().asend(None) + + mock_sleep.assert_not_called() + assert event == TriggerEvent( + { + "status": "timeout", + "message": f"Timed out waiting for Agent Engine query job {QUERY_OPERATION_NAME}", + "operation_name": QUERY_OPERATION_NAME, + } + ) + + @pytest.mark.asyncio + 
@mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_error_event(self, mock_hook, query_job_trigger): + mock_hook.return_value.check_query_agent_engine_job.side_effect = RuntimeError("boom") + + event = await query_job_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": "boom", + "operation_name": QUERY_OPERATION_NAME, + } + ) From f577886c3665a871f606f878128b1357b0c7b2a4 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 19 Jun 2026 08:34:43 -0300 Subject: [PATCH 19/32] Fix _serialize_value duplication and unknown status handling in Agent Engine - Move _serialize_value from operators and triggers into the hook module so both import from a single source of truth - Warn on unknown query job status values in wait_for_query_agent_engine_job to surface unexpected API responses instead of silently looping - Add round-trip serialize test for AgentEngineQueryJobTrigger with a Pydantic model config to verify _serialize_value is applied on serialize() --- .../google/cloud/hooks/vertex_ai/agent_engine.py | 15 +++++++++++++++ .../cloud/operators/vertex_ai/agent_engine.py | 14 +------------- .../providers/google/cloud/triggers/vertex_ai.py | 14 +------------- .../cloud/triggers/test_vertex_ai_agent_engine.py | 13 +++++++++++++ 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 3717a7a0804ae..9edf984b1f53a 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -37,6 +37,19 @@ from vertexai._genai import types +def _serialize_value(value: Any) -> Any: + """Recursively 
convert SDK model objects to JSON-serializable types.""" + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {key: _serialize_value(item) for key, item in value.items()} + if isinstance(value, list): + return [_serialize_value(item) for item in value] + if isinstance(value, tuple): + return tuple(_serialize_value(item) for item in value) + return value + + class AgentEngineHook(GoogleBaseHook): """Hook for Google Cloud Vertex AI Agent Engine APIs.""" @@ -179,6 +192,8 @@ def wait_for_query_agent_engine_job( return query_job if status == "FAILED": raise RuntimeError(f"Agent Engine query job {operation_name} failed.") + if status is not None: + self.log.warning("Unknown Agent Engine query job status: %s", status) if timeout is not None and time.monotonic() - start_time > timeout: raise TimeoutError(f"Timed out waiting for Agent Engine query job {operation_name}") self.log.info("Waiting for Agent Engine query job %s to complete.", operation_name) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 885ff2c11378d..dbc5fca471379 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.sdk import conf -from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook, _serialize_value from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.cloud.triggers.vertex_ai import ( AgentEngineDeleteTrigger, @@ -37,18 +37,6 @@ from airflow.providers.common.compat.sdk import Context 
-def _serialize_value(value: Any) -> Any: - if hasattr(value, "model_dump"): - return value.model_dump(mode="json") - if isinstance(value, dict): - return {key: _serialize_value(item) for key, item in value.items()} - if isinstance(value, list): - return [_serialize_value(item) for item in value] - if isinstance(value, tuple): - return tuple(_serialize_value(item) for item in value) - return value - - def _serialize_agent_engine(agent_engine: types.AgentEngine) -> dict[str, Any]: api_resource = getattr(agent_engine, "api_resource", None) if api_resource is not None: diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index 5b95cfe3eca34..9bf201e43f029 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -31,7 +31,7 @@ ) from airflow.providers.common.compat.sdk import AirflowException -from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook, _serialize_value from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( @@ -44,18 +44,6 @@ from proto import Message -def _serialize_value(value: Any) -> Any: - if hasattr(value, "model_dump"): - return value.model_dump(mode="json") - if isinstance(value, dict): - return {key: _serialize_value(item) for key, item in value.items()} - if isinstance(value, list): - return [_serialize_value(item) for item in value] - if isinstance(value, tuple): - return tuple(_serialize_value(item) for item in value) - return value - - class 
BaseVertexAIJobTrigger(BaseTrigger): """ Base class for Vertex AI job triggers. diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index e62fedf0e5bae..d15947c0d2f36 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -178,6 +178,19 @@ def test_serialize(self, query_job_trigger): }, ) + def test_serialize_with_pydantic_config(self): + pydantic_config = FakeModel(CHECK_QUERY_CONFIG) + trigger = AgentEngineQueryJobTrigger( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=pydantic_config, + gcp_conn_id=GCP_CONN_ID, + poll_interval=1, + ) + _, kwargs = trigger.serialize() + assert kwargs["config"] == CHECK_QUERY_CONFIG + @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_success_event(self, mock_hook, query_job_trigger): From 587c5dee934b337d1fd0d0ebb77936052a1b3b39 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 19 Jun 2026 17:10:54 -0300 Subject: [PATCH 20/32] Fix timeout comparison to >= in Agent Engine hook polling loops --- .../providers/google/cloud/hooks/vertex_ai/agent_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 9edf984b1f53a..7d2eb10b58d76 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -194,7 +194,7 @@ def wait_for_query_agent_engine_job( raise 
RuntimeError(f"Agent Engine query job {operation_name} failed.") if status is not None: self.log.warning("Unknown Agent Engine query job status: %s", status) - if timeout is not None and time.monotonic() - start_time > timeout: + if timeout is not None and time.monotonic() - start_time >= timeout: raise TimeoutError(f"Timed out waiting for Agent Engine query job {operation_name}") self.log.info("Waiting for Agent Engine query job %s to complete.", operation_name) time.sleep(poll_interval) @@ -282,7 +282,7 @@ def wait_for_agent_engine_operation( f"Agent Engine operation {operation_name} failed: {operation['error']}" ) return - if timeout is not None and time.monotonic() - start_time > timeout: + if timeout is not None and time.monotonic() - start_time >= timeout: raise TimeoutError(f"Timed out waiting for Agent Engine operation {operation_name}") self.log.info("Waiting for Agent Engine operation %s to complete.", operation_name) time.sleep(poll_interval) From 3f20a0ddf98367f0108c130e2602518f23f8eaa0 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 19 Jun 2026 17:59:44 -0300 Subject: [PATCH 21/32] Fix mypy errors in Agent Engine trigger and system test config types --- .../google/cloud/operators/vertex_ai/agent_engine.py | 2 +- .../airflow/providers/google/cloud/triggers/vertex_ai.py | 3 ++- .../cloud/vertex_ai/example_vertex_ai_agent_engine.py | 6 +++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index dbc5fca471379..64f437e978c95 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -276,7 +276,7 @@ def execute(self, context: Context) -> dict[str, Any]: 
project_id=self.project_id, location=self.location, operation_name=self.operation_name, - config=_serialize_value(self.config), + config=self.config, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, poll_interval=self.poll_interval, diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index 9bf201e43f029..fb1b5c694df9d 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from proto import Message + from vertexai._genai import types as vertexai_types class BaseVertexAIJobTrigger(BaseTrigger): @@ -245,7 +246,7 @@ def __init__( project_id: str, location: str, operation_name: str, - config: dict[str, Any] | None = None, + config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, poll_interval: float = 30, diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index ae4593cab91c0..99fcd2b41a9c7 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -24,6 +24,10 @@ import os from datetime import datetime +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vertexai._genai import types as vertexai_types from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( @@ -56,7 +60,7 @@ BUCKET_NAME = f"bucket-{DAG_ID}-{ENV_ID}".replace("_", "-") DISPLAY_NAME = 
f"airflow-agent-engine-{ENV_ID}" -QUERY_CONFIG = { +QUERY_CONFIG: vertexai_types.RunQueryJobAgentEngineConfigDict = { "query": "Respond with a short acknowledgement.", "output_gcs_uri": f"gs://{BUCKET_NAME}/query-output/", } From 78d1e437c5e6ff5f8f2f14d06baa9893bf1e2c1f Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Sat, 20 Jun 2026 17:22:43 -0300 Subject: [PATCH 22/32] Handle Agent Engine polling edge cases --- .../cloud/hooks/vertex_ai/agent_engine.py | 6 +- .../cloud/operators/vertex_ai/agent_engine.py | 1 - .../google/cloud/triggers/vertex_ai.py | 40 +++---- .../hooks/vertex_ai/test_agent_engine.py | 106 +++++++++++++++++- .../operators/vertex_ai/test_agent_engine.py | 19 +++- .../triggers/test_vertex_ai_agent_engine.py | 77 ++++++++++++- 6 files changed, 218 insertions(+), 31 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 7d2eb10b58d76..4ba8afe76c084 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -192,8 +192,10 @@ def wait_for_query_agent_engine_job( return query_job if status == "FAILED": raise RuntimeError(f"Agent Engine query job {operation_name} failed.") - if status is not None: - self.log.warning("Unknown Agent Engine query job status: %s", status) + if status not in (None, "RUNNING"): + raise RuntimeError( + f"Agent Engine query job {operation_name} completed with unexpected status {status}." 
+ ) if timeout is not None and time.monotonic() - start_time >= timeout: raise TimeoutError(f"Timed out waiting for Agent Engine query job {operation_name}") self.log.info("Waiting for Agent Engine query job %s to complete.", operation_name) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 64f437e978c95..34db575fc1737 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -460,7 +460,6 @@ def execute(self, context: Context) -> dict[str, Any]: if self.deferrable: self.defer( trigger=AgentEngineDeleteTrigger( - project_id=self.project_id, location=self.location, agent_engine_id=self.agent_engine_id, gcp_conn_id=self.gcp_conn_id, diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index fb1b5c694df9d..0f9b3b6984cc0 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -135,17 +135,15 @@ class AgentEngineDeleteTrigger(BaseTrigger): def __init__( self, - project_id: str, location: str, agent_engine_id: str, + operation_name: str, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, poll_interval: float = 30, timeout: float | None = None, - operation_name: str | None = None, ): super().__init__() - self.project_id = project_id self.location = location self.agent_engine_id = agent_engine_id self.gcp_conn_id = gcp_conn_id @@ -158,7 +156,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger", { - "project_id": self.project_id, 
"location": self.location, "agent_engine_id": self.agent_engine_id, "gcp_conn_id": self.gcp_conn_id, @@ -177,16 +174,6 @@ def async_hook(self) -> AgentEngineAsyncHook: ) async def run(self) -> AsyncIterator[TriggerEvent]: - if not self.operation_name: - yield TriggerEvent( - { - "status": "error", - "message": "Delete Agent Engine operation name is required.", - "agent_engine_id": self.agent_engine_id, - } - ) - return - start_time = time.monotonic() try: while True: @@ -232,7 +219,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent( { "status": "error", - "message": str(err), + "message": f"Failed while polling Agent Engine deletion: {err}", "agent_engine_id": self.agent_engine_id, } ) @@ -295,12 +282,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]: config=self.config, ) status = getattr(query_job, "status", None) + serialized_query_job = _serialize_value(query_job) if status == "SUCCESS": yield TriggerEvent( { "status": "success", "message": "Agent Engine query job completed", - "query_job": _serialize_value(query_job), + "query_job": serialized_query_job, } ) return @@ -309,7 +297,19 @@ async def run(self) -> AsyncIterator[TriggerEvent]: { "status": "error", "message": f"Agent Engine query job {self.operation_name} failed.", - "query_job": _serialize_value(query_job), + "query_job": serialized_query_job, + } + ) + return + if status not in (None, "RUNNING"): + yield TriggerEvent( + { + "status": "error", + "message": ( + f"Agent Engine query job {self.operation_name} completed with " + f"unexpected status {status}." 
+ ), + "query_job": serialized_query_job, } ) return @@ -319,7 +319,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: { "status": "timeout", "message": f"Timed out waiting for Agent Engine query job {self.operation_name}", - "operation_name": self.operation_name, + "query_job": serialized_query_job, } ) return @@ -331,8 +331,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent( { "status": "error", - "message": str(err), - "operation_name": self.operation_name, + "message": f"Failed while polling Agent Engine query job: {err}", + "query_job": {"operation_name": self.operation_name}, } ) diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index b3fd3e20676fc..87c9984f03c3a 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -21,7 +21,7 @@ import pytest -from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook, AgentEngineHook from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id @@ -136,6 +136,26 @@ def test_wait_for_query_agent_engine_job_returns_when_successful(self, mock_chec ) assert result == mock_check_query_job.return_value + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_polls_until_success(self, mock_check_query_job, mock_sleep): + running_job = mock.Mock(status="RUNNING") + success_job = mock.Mock(status="SUCCESS") + mock_check_query_job.side_effect = [running_job, running_job, success_job] + + result = self.hook.wait_for_query_agent_engine_job( + 
project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + poll_interval=10, + ) + + assert result is success_job + assert mock_check_query_job.call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(10) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) def test_wait_for_query_agent_engine_job_raises_on_failed_status(self, mock_check_query_job): mock_check_query_job.return_value.status = "FAILED" @@ -148,6 +168,26 @@ def test_wait_for_query_agent_engine_job_raises_on_failed_status(self, mock_chec config=CHECK_QUERY_CONFIG, ) + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) + def test_wait_for_query_agent_engine_job_raises_on_unexpected_status( + self, mock_check_query_job, mock_sleep + ): + mock_check_query_job.return_value.status = "CANCELLED" + + with pytest.raises( + RuntimeError, + match=f"Agent Engine query job {QUERY_OPERATION_NAME} completed with unexpected status CANCELLED.", + ): + self.hook.wait_for_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + + mock_sleep.assert_not_called() + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) @mock.patch(AGENT_ENGINE_STRING.format("time.monotonic"), autospec=True) @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.check_query_agent_engine_job"), autospec=True) @@ -233,6 +273,23 @@ def test_wait_for_agent_engine_operation_returns_when_done(self, mock_get_operat operation_name=OPERATION_NAME, ) + @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) + def test_wait_for_agent_engine_operation_polls_until_done(self, 
mock_get_operation, mock_sleep): + running_operation = {"name": OPERATION_NAME, "done": False} + done_operation = {"name": OPERATION_NAME, "done": True} + mock_get_operation.side_effect = [running_operation, running_operation, done_operation] + + self.hook.wait_for_agent_engine_operation( + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + poll_interval=10, + ) + + assert mock_get_operation.call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(10) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) def test_wait_for_agent_engine_operation_raises_on_error(self, mock_get_operation): mock_get_operation.return_value = {"name": OPERATION_NAME, "done": True, "error": {"message": "boom"}} @@ -258,3 +315,50 @@ def test_wait_for_agent_engine_operation_times_out(self, mock_get_operation, moc ) mock_sleep.assert_not_called() + + +class TestAgentEngineAsyncHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseAsyncHook.__init__"), + return_value=None, + ): + self.hook = AgentEngineAsyncHook(gcp_conn_id=TEST_GCP_CONN_ID) + + @pytest.mark.asyncio + async def test_get_agent_engine_operation_calls_sync_hook(self): + sync_hook = mock.Mock(spec=AgentEngineHook) + sync_hook.get_agent_engine_operation.return_value = {"name": OPERATION_NAME, "done": True} + self.hook.get_sync_hook = mock.AsyncMock(return_value=sync_hook) + + result = await self.hook.get_agent_engine_operation( + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + ) + + sync_hook.get_agent_engine_operation.assert_called_once_with( + location=GCP_LOCATION, + operation_name=OPERATION_NAME, + ) + assert result == {"name": OPERATION_NAME, "done": True} + + @pytest.mark.asyncio + async def test_check_query_agent_engine_job_calls_sync_hook(self): + sync_hook = mock.Mock(spec=AgentEngineHook) + sync_hook.check_query_agent_engine_job.return_value = mock.sentinel.query_job + self.hook.get_sync_hook = 
mock.AsyncMock(return_value=sync_hook) + + result = await self.hook.check_query_agent_engine_job( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + + sync_hook.check_query_agent_engine_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_name=QUERY_OPERATION_NAME, + config=CHECK_QUERY_CONFIG, + ) + assert result == mock.sentinel.query_job diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index 9de7f6da13099..954aa5747e779 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -369,6 +369,24 @@ def test_execute_does_not_wait_when_delete_operation_is_done(self, mock_hook, co mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() assert result == operation + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_raises_when_delete_operation_has_no_name(self, mock_hook, context): + mock_hook.return_value.delete_agent_engine.return_value = FakeModel({"done": False}) + op = DeleteAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + wait_for_completion=True, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + with pytest.raises(RuntimeError, match="Delete Agent Engine operation did not include"): + op.execute(context=context) + + mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineDeleteTrigger"), autospec=True) @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) def test_execute_deferrable(self, mock_hook, mock_trigger, context): @@ -390,7 +408,6 @@ def 
test_execute_deferrable(self, mock_hook, mock_trigger, context): op.execute(context=context) mock_trigger.assert_called_once_with( - project_id=GCP_PROJECT, location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, gcp_conn_id=GCP_CONN_ID, diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index d15947c0d2f36..f823b35cfd9c3 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -50,7 +50,6 @@ def model_dump(self, mode="json"): @pytest.fixture def delete_trigger(): return AgentEngineDeleteTrigger( - project_id=GCP_PROJECT, location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, gcp_conn_id=GCP_CONN_ID, @@ -80,7 +79,6 @@ def test_serialize(self, delete_trigger): assert delete_trigger.serialize() == ( "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger", { - "project_id": GCP_PROJECT, "location": GCP_LOCATION, "agent_engine_id": AGENT_ENGINE_ID, "gcp_conn_id": GCP_CONN_ID, @@ -110,6 +108,29 @@ async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): } ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_polls_until_success(self, mock_hook, mock_sleep, delete_trigger): + mock_hook.return_value.get_agent_engine_operation.side_effect = [ + {"done": False}, + {"done": False}, + {"done": True}, + ] + + event = await delete_trigger.run().asend(None) + + assert mock_hook.return_value.get_agent_engine_operation.call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_awaited_with(1) + assert event == TriggerEvent( + { + "status": "success", + "message": "Agent Engine 
deleted", + "agent_engine_id": AGENT_ENGINE_ID, + } + ) + @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) @@ -138,7 +159,7 @@ async def test_run_loop_return_error_event(self, mock_hook, delete_trigger): assert event == TriggerEvent( { "status": "error", - "message": "boom", + "message": "Failed while polling Agent Engine deletion: boom", "agent_engine_id": AGENT_ENGINE_ID, } ) @@ -218,6 +239,31 @@ async def test_run_loop_return_success_event(self, mock_hook, query_job_trigger) } ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_polls_until_success(self, mock_hook, mock_sleep, query_job_trigger): + running_job = FakeModel({"operation_name": QUERY_OPERATION_NAME, "status": "RUNNING"}) + success_job = FakeModel({"operation_name": QUERY_OPERATION_NAME, "status": "SUCCESS"}) + mock_hook.return_value.check_query_agent_engine_job.side_effect = [ + running_job, + running_job, + success_job, + ] + + event = await query_job_trigger.run().asend(None) + + assert mock_hook.return_value.check_query_agent_engine_job.call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_awaited_with(1) + assert event == TriggerEvent( + { + "status": "success", + "message": "Agent Engine query job completed", + "query_job": {"operation_name": QUERY_OPERATION_NAME, "status": "SUCCESS"}, + } + ) + @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) async def test_run_loop_return_failed_event(self, mock_hook, query_job_trigger): @@ -248,7 +294,26 @@ async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, query_ { "status": "timeout", 
"message": f"Timed out waiting for Agent Engine query job {QUERY_OPERATION_NAME}", - "operation_name": QUERY_OPERATION_NAME, + "query_job": {"status": "RUNNING"}, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) + async def test_run_loop_return_error_event_for_unexpected_status(self, mock_hook, query_job_trigger): + query_job = {"operation_name": QUERY_OPERATION_NAME, "status": "CANCELLED"} + mock_hook.return_value.check_query_agent_engine_job.return_value = FakeModel(query_job) + + event = await query_job_trigger.run().asend(None) + + assert event == TriggerEvent( + { + "status": "error", + "message": ( + f"Agent Engine query job {QUERY_OPERATION_NAME} completed with " + "unexpected status CANCELLED." + ), + "query_job": query_job, } ) @@ -262,7 +327,7 @@ async def test_run_loop_return_error_event(self, mock_hook, query_job_trigger): assert event == TriggerEvent( { "status": "error", - "message": "boom", - "operation_name": QUERY_OPERATION_NAME, + "message": "Failed while polling Agent Engine query job: boom", + "query_job": {"operation_name": QUERY_OPERATION_NAME}, } ) From e012207af76b7a67e70b4b8d6aad610cda6d34f1 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Sat, 20 Jun 2026 17:50:23 -0300 Subject: [PATCH 23/32] Clarify Agent Engine delete operation docs --- providers/google/docs/operators/cloud/vertex_ai.rst | 2 +- .../google/cloud/operators/vertex_ai/agent_engine.py | 2 +- .../airflow/providers/google/cloud/triggers/vertex_ai.py | 2 +- .../cloud/vertex_ai/example_vertex_ai_agent_engine.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index 510eb82e4a348..ab15e4ac57b33 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ 
b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -85,7 +85,7 @@ To update an Agent Engine you can use To delete an Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.DeleteAgentEngineOperator`. -By default, the operator waits until the Agent Engine no longer exists. The same operation +By default, the operator waits until the delete operation completes. The same operation can be performed in the deferrable mode. .. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 34db575fc1737..df6f1478b3630 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -382,7 +382,7 @@ class DeleteAgentEngineOperator(GoogleCloudBaseOperator): :param agent_engine_id: Required. The Agent Engine ID. :param force: Optional. Whether to delete child resources. :param config: Optional. Additional deletion configuration. - :param wait_for_completion: Whether to wait until the Agent Engine no longer exists. + :param wait_for_completion: Whether to wait until the delete operation completes. :param poll_interval: Time, in seconds, to wait between checks. :param timeout: Optional timeout, in seconds. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index 0f9b3b6984cc0..b82a416d26df1 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -131,7 +131,7 @@ def _serialize_job(self, job: Any) -> Any: class AgentEngineDeleteTrigger(BaseTrigger): - """Trigger that waits until a Vertex AI Agent Engine no longer exists.""" + """Trigger that waits until a Vertex AI Agent Engine delete operation completes.""" def __init__( self, diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index 99fcd2b41a9c7..b59cfd6bd8e4d 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -17,7 +17,7 @@ # under the License. """ -Example Airflow DAG for Google Vertex AI Agent Engine operations. +Example Airflow Dag for Google Vertex AI Agent Engine operations. 
""" from __future__ import annotations @@ -179,10 +179,10 @@ from tests_common.test_utils.watcher import watcher # This test needs watcher in order to properly mark success/failure - # when "tearDown" task with trigger rule is part of the DAG + # when "tearDown" task with trigger rule is part of the Dag list(dag.tasks) >> watcher() from tests_common.test_utils.system_tests import get_test_run # noqa: E402 -# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst) +# Needed to run the example Dag with pytest (see: contributing-docs/testing/system_tests.rst) test_run = get_test_run(dag) From 7025dde68792289ec67866e0b9bdfa7d82c8bc76 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Wed, 24 Jun 2026 14:43:59 -0300 Subject: [PATCH 24/32] Tighten Agent Engine delete operation test --- .../google/cloud/operators/vertex_ai/test_agent_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index 954aa5747e779..c852d685c6877 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -382,7 +382,9 @@ def test_execute_raises_when_delete_operation_has_no_name(self, mock_hook, conte impersonation_chain=IMPERSONATION_CHAIN, ) - with pytest.raises(RuntimeError, match="Delete Agent Engine operation did not include"): + with pytest.raises( + RuntimeError, match=r"Delete Agent Engine operation did not include an operation name\." 
+ ): op.execute(context=context) mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() From 4455a29fff9a9ce339061e160f83ac6e2805b28d Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 26 Jun 2026 17:11:03 -0300 Subject: [PATCH 25/32] Improve Vertex AI Agent Engine query operators --- .../google/docs/operators/cloud/vertex_ai.rst | 31 +-- .../cloud/hooks/vertex_ai/agent_engine.py | 121 ++++++-- .../cloud/operators/vertex_ai/agent_engine.py | 177 +++++------- .../google/cloud/triggers/vertex_ai.py | 119 +------- .../example_vertex_ai_agent_engine.py | 47 ++-- .../hooks/vertex_ai/test_agent_engine.py | 103 +++++-- .../operators/vertex_ai/test_agent_engine.py | 262 ++++++++++-------- .../triggers/test_vertex_ai_agent_engine.py | 146 +--------- 8 files changed, 437 insertions(+), 569 deletions(-) diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index ab15e4ac57b33..96710185a922f 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -48,31 +48,19 @@ To get an Agent Engine you can use :end-before: [END how_to_cloud_vertex_ai_get_agent_engine_operator] To run a query job on an Agent Engine you can use -:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.QueryAgentEngineOperator`. +:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.RunQueryJobOperator`. The operator uses the public ``run_query_job`` SDK method. The ``config`` parameter -is required and must include ``query`` and ``output_gcs_uri``. The SDK writes query -input and output through Google Cloud Storage and the operator returns the serialized -query job result metadata, including the job name and GCS URIs. This operator starts -the query job and does not wait for the query job output. +can include ``query`` and ``output_gcs_uri``. 
The SDK writes query input and output +through Google Cloud Storage. By default, the operator waits for the query job to +complete and returns the serialized query job result. Set ``retrieve_result`` to +``True`` in ``check_config`` to return the query job result from Google Cloud Storage. +The same operation can be performed in the deferrable mode. .. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py :language: python :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_query_agent_engine_operator] - :end-before: [END how_to_cloud_vertex_ai_query_agent_engine_operator] - -To wait for a query job and retrieve its output you can use -:class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.CheckQueryAgentEngineOperator`. -The operator uses the public ``check_query_job`` SDK method. The ``config`` parameter -is optional. Set ``retrieve_result`` to ``True`` in ``config`` to return the query job -result from Google Cloud Storage. The same operation can be performed in the -deferrable mode. - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_check_query_agent_engine_operator] - :end-before: [END how_to_cloud_vertex_ai_check_query_agent_engine_operator] + :start-after: [START how_to_cloud_vertex_ai_run_query_job_operator] + :end-before: [END how_to_cloud_vertex_ai_run_query_job_operator] To update an Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.UpdateAgentEngineOperator`. @@ -85,8 +73,7 @@ To update an Agent Engine you can use To delete an Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.DeleteAgentEngineOperator`. -By default, the operator waits until the delete operation completes. The same operation -can be performed in the deferrable mode. 
+By default, the operator waits until the delete operation completes. .. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py :language: python diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 4ba8afe76c084..a0e016d192729 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -37,16 +37,28 @@ from vertexai._genai import types -def _serialize_value(value: Any) -> Any: +VERTEX_AI_AGENT_ENGINE_API_VERSION = "v1beta1" +VERTEX_AI_AGENT_ENGINE_OPERATION_URL = ( + "https://{location}-aiplatform.googleapis.com/{api_version}/{operation_name}" +) +DEFAULT_AGENT_ENGINE_OPERATION_REQUEST_TIMEOUT = 60.0 + + +def extract_operation_id(operation_name: str) -> str: + """Extract the operation ID from a fully qualified operation name.""" + return operation_name.rstrip("/").split("/")[-1] + + +def serialize_value(value: Any) -> Any: """Recursively convert SDK model objects to JSON-serializable types.""" if hasattr(value, "model_dump"): return value.model_dump(mode="json") if isinstance(value, dict): - return {key: _serialize_value(item) for key, item in value.items()} + return {key: serialize_value(item) for key, item in value.items()} if isinstance(value, list): - return [_serialize_value(item) for item in value] + return [serialize_value(item) for item in value] if isinstance(value, tuple): - return tuple(_serialize_value(item) for item in value) + return tuple(serialize_value(item) for item in value) return value @@ -78,10 +90,21 @@ def build_agent_engine_name(project_id: str, location: str, agent_engine_id: str """Build a fully qualified Agent Engine resource name.""" return f"projects/{project_id}/locations/{location}/reasoningEngines/{agent_engine_id}" + 
@staticmethod + def extract_operation_id(operation_name: str) -> str: + """Extract the operation ID from a fully qualified operation name.""" + return extract_operation_id(operation_name) + + @staticmethod + def build_operation_name(project_id: str, location: str, operation_id: str) -> str: + """Build a fully qualified Agent Engine operation name.""" + return f"projects/{project_id}/locations/{location}/operations/{operation_id}" + @GoogleBaseHook.fallback_to_default_project_id def create_agent_engine( self, location: str, + agent_engine: Any | None = None, agent: Any | None = None, config: types.AgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, @@ -90,19 +113,21 @@ def create_agent_engine( Create an Agent Engine. :param location: Required. The ID of the Google Cloud location that the service belongs to. + :param agent_engine: Optional. Deprecated alias for ``agent``. :param agent: Optional. The agent object to deploy. :param config: Optional. Configuration for the Agent Engine. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) - return client.create(agent=agent, config=config) + return client.create(agent_engine=agent_engine, agent=agent, config=config) @GoogleBaseHook.fallback_to_default_project_id def get_agent_engine( self, location: str, agent_engine_id: str, + config: types.GetAgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.AgentEngine: """ @@ -110,19 +135,20 @@ def get_agent_engine( :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. + :param config: Optional. Configuration for getting the Agent Engine. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. 
""" client = self.get_agent_engine_client(project_id=project_id, location=location) name = self.build_agent_engine_name(project_id, location, agent_engine_id) - return client.get(name=name) + return client.get(name=name, config=config) @GoogleBaseHook.fallback_to_default_project_id - def query_agent_engine( + def run_query_job( self, location: str, agent_engine_id: str, - config: types.RunQueryJobAgentEngineConfigOrDict, + config: types.RunQueryJobAgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.RunQueryJobResult: """ @@ -130,7 +156,7 @@ def query_agent_engine( :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. - :param config: Required. Configuration for the query job (``query``, ``output_gcs_uri``). + :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ @@ -142,7 +168,7 @@ def query_agent_engine( def check_query_agent_engine_job( self, location: str, - operation_name: str, + operation_id: str, config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.CheckQueryJobResult: @@ -150,19 +176,20 @@ def check_query_agent_engine_job( Check a query job on an Agent Engine. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param operation_name: Required. The query job operation name. + :param operation_id: Required. The query job operation ID. :param config: Optional. Configuration for checking the query job. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. 
""" client = self.get_agent_engine_client(project_id=project_id, location=location) + operation_name = self.build_operation_name(project_id, location, operation_id) return client.check_query_job(name=operation_name, config=config) @GoogleBaseHook.fallback_to_default_project_id def wait_for_query_agent_engine_job( self, location: str, - operation_name: str, + operation_id: str, config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, poll_interval: float = 30, timeout: float | None = None, @@ -172,7 +199,7 @@ def wait_for_query_agent_engine_job( Wait until an Agent Engine query job completes. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param operation_name: Required. The query job operation name. + :param operation_id: Required. The query job operation ID. :param config: Optional. Configuration for checking the query job. :param poll_interval: Time, in seconds, to wait between checks. :param timeout: Optional timeout, in seconds. @@ -180,11 +207,12 @@ def wait_for_query_agent_engine_job( configured in the connection. """ start_time = time.monotonic() + operation_name = self.build_operation_name(project_id, location, operation_id) while True: query_job = self.check_query_agent_engine_job( project_id=project_id, location=location, - operation_name=operation_name, + operation_id=operation_id, config=config, ) status = getattr(query_job, "status", None) @@ -208,6 +236,7 @@ def update_agent_engine( agent_engine_id: str, config: types.AgentEngineConfigOrDict, agent: Any | None = None, + agent_engine: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.AgentEngine: """ @@ -217,12 +246,13 @@ def update_agent_engine( :param agent_engine_id: Required. The Agent Engine ID. :param config: Required. Configuration for the Agent Engine update. :param agent: Optional. The updated agent object to deploy. + :param agent_engine: Optional. Deprecated alias for ``agent``. :param project_id: Optional. 
The ID of the Google Cloud project. Defaults to the project configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) name = self.build_agent_engine_name(project_id, location, agent_engine_id) - return client.update(name=name, agent=agent, config=config) + return client.update(name=name, agent=agent, agent_engine=agent_engine, config=config) @GoogleBaseHook.fallback_to_default_project_id def delete_agent_engine( @@ -248,36 +278,61 @@ def delete_agent_engine( name = self.build_agent_engine_name(project_id, location, agent_engine_id) return client.delete(name=name, force=force, config=config) - def get_agent_engine_operation(self, location: str, operation_name: str) -> dict[str, Any]: - """Return a Vertex AI Agent Engine long-running operation.""" - url = ( - operation_name - if operation_name.startswith("http") - else f"https://{location}-aiplatform.googleapis.com/v1beta1/{operation_name}" + @GoogleBaseHook.fallback_to_default_project_id + def get_agent_engine_operation( + self, + location: str, + operation_id: str, + request_timeout: float | None = DEFAULT_AGENT_ENGINE_OPERATION_REQUEST_TIMEOUT, + project_id: str = PROVIDE_PROJECT_ID, + ) -> dict[str, Any]: + """ + Return a Vertex AI Agent Engine long-running operation. + + :param location: The ID of the Google Cloud location that the service belongs to. + :param operation_id: The Agent Engine operation ID. + :param request_timeout: Optional timeout, in seconds, for the operation request. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. 
+ """ + operation_name = self.build_operation_name(project_id, location, operation_id) + url = VERTEX_AI_AGENT_ENGINE_OPERATION_URL.format( + location=location, + api_version=VERTEX_AI_AGENT_ENGINE_API_VERSION, + operation_name=operation_name, ) session = google.auth.transport.requests.AuthorizedSession(self.get_credentials()) - response = session.get(url) + response = session.get(url, timeout=request_timeout) response.raise_for_status() return response.json() + @GoogleBaseHook.fallback_to_default_project_id def wait_for_agent_engine_operation( self, location: str, - operation_name: str, + operation_id: str, poll_interval: float = 30, timeout: float | None = None, + project_id: str = PROVIDE_PROJECT_ID, ) -> None: """ Wait until an Agent Engine operation completes. :param location: The ID of the Google Cloud location that the service belongs to. - :param operation_name: The Agent Engine operation name. + :param operation_id: The Agent Engine operation ID. :param poll_interval: Time, in seconds, to wait between checks. :param timeout: Optional timeout, in seconds. + :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project + configured in the connection. 
""" start_time = time.monotonic() + operation_name = self.build_operation_name(project_id, location, operation_id) while True: - operation = self.get_agent_engine_operation(location=location, operation_name=operation_name) + operation = self.get_agent_engine_operation( + project_id=project_id, + location=location, + operation_id=operation_id, + ) if operation.get("done"): if operation.get("error"): raise RuntimeError( @@ -307,18 +362,26 @@ def __init__( **kwargs, ) - async def get_agent_engine_operation(self, location: str, operation_name: str) -> dict[str, Any]: + async def get_agent_engine_operation( + self, + location: str, + operation_id: str, + request_timeout: float | None = DEFAULT_AGENT_ENGINE_OPERATION_REQUEST_TIMEOUT, + project_id: str = PROVIDE_PROJECT_ID, + ) -> dict[str, Any]: """Return a Vertex AI Agent Engine long-running operation.""" sync_hook = await self.get_sync_hook() return await sync_to_async(sync_hook.get_agent_engine_operation)( + project_id=project_id, location=location, - operation_name=operation_name, + operation_id=operation_id, + request_timeout=request_timeout, ) async def check_query_agent_engine_job( self, location: str, - operation_name: str, + operation_id: str, config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.CheckQueryJobResult: @@ -327,6 +390,6 @@ async def check_query_agent_engine_job( return await sync_to_async(sync_hook.check_query_agent_engine_job)( project_id=project_id, location=location, - operation_name=operation_name, + operation_id=operation_id, config=config, ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index df6f1478b3630..f576c3f06905c 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ 
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -24,12 +24,13 @@ from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.sdk import conf -from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineHook, _serialize_value -from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.cloud.triggers.vertex_ai import ( - AgentEngineDeleteTrigger, - AgentEngineQueryJobTrigger, +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import ( + AgentEngineHook, + extract_operation_id, + serialize_value, ) +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineQueryJobTrigger if TYPE_CHECKING: from vertexai._genai import types @@ -40,8 +41,8 @@ def _serialize_agent_engine(agent_engine: types.AgentEngine) -> dict[str, Any]: api_resource = getattr(agent_engine, "api_resource", None) if api_resource is not None: - return _serialize_value(api_resource) - return _serialize_value(agent_engine) + return serialize_value(api_resource) + return serialize_value(agent_engine) class CreateAgentEngineOperator(GoogleCloudBaseOperator): @@ -51,6 +52,7 @@ class CreateAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent: Optional. The agent object to deploy. + :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Optional. Configuration for the Agent Engine. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
@@ -60,6 +62,7 @@ class CreateAgentEngineOperator(GoogleCloudBaseOperator): "project_id", "location", "agent", + "agent_engine", "config", "gcp_conn_id", "impersonation_chain", @@ -71,6 +74,7 @@ def __init__( project_id: str, location: str, agent: Any | None = None, + agent_engine: Any | None = None, config: types.AgentEngineConfigOrDict | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -80,6 +84,7 @@ def __init__( self.project_id = project_id self.location = location self.agent = agent + self.agent_engine = agent_engine self.config = config self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -97,6 +102,7 @@ def execute(self, context: Context) -> dict[str, Any]: project_id=self.project_id, location=self.location, agent=self.agent, + agent_engine=self.agent_engine, config=self.config, ) result = _serialize_agent_engine(agent_engine) @@ -111,56 +117,7 @@ class GetAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
- """ - - template_fields = ("project_id", "location", "agent_engine_id", "gcp_conn_id", "impersonation_chain") - - def __init__( - self, - *, - project_id: str, - location: str, - agent_engine_id: str, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.agent_engine_id = agent_engine_id - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - @cached_property - def hook(self) -> AgentEngineHook: - return AgentEngineHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - - def execute(self, context: Context) -> dict[str, Any]: - self.log.info("Getting Agent Engine %s.", self.agent_engine_id) - agent_engine = self.hook.get_agent_engine( - project_id=self.project_id, - location=self.location, - agent_engine_id=self.agent_engine_id, - ) - result = _serialize_agent_engine(agent_engine) - self.log.info("Agent Engine %s was retrieved.", self.agent_engine_id) - return result - - -class QueryAgentEngineOperator(GoogleCloudBaseOperator): - """ - Run a query job on a Vertex AI Agent Engine. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param agent_engine_id: Required. The Agent Engine ID. - :param config: Required. Configuration for the query job (``query``, ``output_gcs_uri``). + :param config: Optional. Configuration for getting the Agent Engine. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
""" @@ -180,7 +137,7 @@ def __init__( project_id: str, location: str, agent_engine_id: str, - config: types.RunQueryJobAgentEngineConfigOrDict, + config: types.GetAgentEngineConfigOrDict | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -201,25 +158,28 @@ def hook(self) -> AgentEngineHook: ) def execute(self, context: Context) -> dict[str, Any]: - self.log.info("Running query job on Agent Engine %s.", self.agent_engine_id) - query_job = self.hook.query_agent_engine( + self.log.info("Getting Agent Engine %s.", self.agent_engine_id) + agent_engine = self.hook.get_agent_engine( project_id=self.project_id, location=self.location, agent_engine_id=self.agent_engine_id, config=self.config, ) - self.log.info("Query job was started on Agent Engine %s.", self.agent_engine_id) - return _serialize_value(query_job) + result = _serialize_agent_engine(agent_engine) + self.log.info("Agent Engine %s was retrieved.", self.agent_engine_id) + return result -class CheckQueryAgentEngineOperator(GoogleCloudBaseOperator): +class RunQueryJobOperator(GoogleCloudBaseOperator): """ - Check a query job on a Vertex AI Agent Engine. + Run a query job on a Vertex AI Agent Engine. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param operation_name: Required. The query job operation name (e.g. from the ``job_name`` field of the result of ``QueryAgentEngineOperator``). - :param config: Optional. Configuration for checking the query job. + :param agent_engine_id: Required. The Agent Engine ID. + :param config: Optional. Configuration for the query job (``query``, ``output_gcs_uri``). + :param check_config: Optional. Configuration for checking the query job. + :param wait_for_completion: Whether to wait until the query job completes. 
:param poll_interval: Time, in seconds, to wait between checks. :param timeout: Optional timeout, in seconds. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. @@ -230,8 +190,9 @@ class CheckQueryAgentEngineOperator(GoogleCloudBaseOperator): template_fields = ( "project_id", "location", - "operation_name", + "agent_engine_id", "config", + "check_config", "gcp_conn_id", "impersonation_chain", ) @@ -241,8 +202,10 @@ def __init__( *, project_id: str, location: str, - operation_name: str, - config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + agent_engine_id: str, + config: types.RunQueryJobAgentEngineConfigOrDict | None = None, + check_config: types.CheckQueryJobAgentEngineConfigOrDict | None = None, + wait_for_completion: bool = True, poll_interval: float = 30, timeout: float | None = None, gcp_conn_id: str = "google_cloud_default", @@ -253,8 +216,10 @@ def __init__( super().__init__(**kwargs) self.project_id = project_id self.location = location - self.operation_name = operation_name + self.agent_engine_id = agent_engine_id self.config = config + self.check_config = check_config + self.wait_for_completion = wait_for_completion self.poll_interval = poll_interval self.timeout = timeout self.gcp_conn_id = gcp_conn_id @@ -269,14 +234,30 @@ def hook(self) -> AgentEngineHook: ) def execute(self, context: Context) -> dict[str, Any]: - self.log.info("Checking Agent Engine query job %s.", self.operation_name) + self.log.info("Running query job on Agent Engine %s.", self.agent_engine_id) + query_job = self.hook.run_query_job( + project_id=self.project_id, + location=self.location, + agent_engine_id=self.agent_engine_id, + config=self.config, + ) + result = serialize_value(query_job) + self.log.info("Query job was started on Agent Engine %s.", self.agent_engine_id) + if not self.wait_for_completion: + return result + + operation_name = getattr(query_job, "job_name", None) + if not operation_name: + raise RuntimeError("Agent Engine query 
job did not include an operation name.") + operation_id = extract_operation_id(operation_name) + if self.deferrable: self.defer( trigger=AgentEngineQueryJobTrigger( project_id=self.project_id, location=self.location, - operation_name=self.operation_name, - config=self.config, + operation_id=operation_id, + config=self.check_config, gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, poll_interval=self.poll_interval, @@ -288,14 +269,13 @@ def execute(self, context: Context) -> dict[str, Any]: query_job = self.hook.wait_for_query_agent_engine_job( project_id=self.project_id, location=self.location, - operation_name=self.operation_name, - config=self.config, + operation_id=operation_id, + config=self.check_config, poll_interval=self.poll_interval, timeout=self.timeout, ) - result = _serialize_value(query_job) - self.log.info("Agent Engine query job %s completed.", self.operation_name) - return result + self.log.info("Agent Engine query job %s completed.", operation_name) + return serialize_value(query_job) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: if event is None: @@ -316,6 +296,7 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. :param agent: Optional. The updated agent object to deploy. + :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Required. Configuration for the Agent Engine update. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
@@ -326,6 +307,7 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): "location", "agent_engine_id", "agent", + "agent_engine", "config", "gcp_conn_id", "impersonation_chain", @@ -339,6 +321,7 @@ def __init__( agent_engine_id: str, config: types.AgentEngineConfigOrDict, agent: Any | None = None, + agent_engine: Any | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -348,6 +331,7 @@ def __init__( self.location = location self.agent_engine_id = agent_engine_id self.agent = agent + self.agent_engine = agent_engine self.config = config self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -366,6 +350,7 @@ def execute(self, context: Context) -> dict[str, Any]: location=self.location, agent_engine_id=self.agent_engine_id, agent=self.agent, + agent_engine=self.agent_engine, config=self.config, ) result = _serialize_agent_engine(agent_engine) @@ -387,7 +372,6 @@ class DeleteAgentEngineOperator(GoogleCloudBaseOperator): :param timeout: Optional timeout, in seconds. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. - :param deferrable: Run operator in the deferrable mode. 
""" template_fields = ( @@ -413,7 +397,6 @@ def __init__( timeout: float | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, - deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -427,7 +410,6 @@ def __init__( self.timeout = timeout self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.deferrable = deferrable @cached_property def hook(self) -> AgentEngineHook: @@ -445,50 +427,25 @@ def execute(self, context: Context) -> dict[str, Any]: force=self.force, config=self.config, ) - result = _serialize_value(operation) + result = serialize_value(operation) if not self.wait_for_completion: return result operation_name = getattr(operation, "name", None) if not operation_name: raise RuntimeError("Delete Agent Engine operation did not include an operation name.") + operation_id = extract_operation_id(operation_name) if getattr(operation, "done", False): self.log.info("Agent Engine %s was deleted.", self.agent_engine_id) return result - if self.deferrable: - self.defer( - trigger=AgentEngineDeleteTrigger( - location=self.location, - agent_engine_id=self.agent_engine_id, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - poll_interval=self.poll_interval, - timeout=self.timeout, - operation_name=operation_name, - ), - method_name="execute_complete", - kwargs={"operation": result}, - ) - self.hook.wait_for_agent_engine_operation( + project_id=self.project_id, location=self.location, - operation_name=operation_name, + operation_id=operation_id, poll_interval=self.poll_interval, timeout=self.timeout, ) self.log.info("Agent Engine %s was deleted.", self.agent_engine_id) return result - - def execute_complete( - self, context: Context, event: dict[str, Any] | None = None, operation: dict[str, Any] | None = None - ) -> dict[str, Any]: - if event is None: - raise 
RuntimeError("No event received in trigger callback") - if event["status"] == "success": - self.log.info("Agent Engine %s deleted.", event["agent_engine_id"]) - return operation or {} - if event["status"] == "timeout": - raise TimeoutError(event["message"]) - raise RuntimeError(event["message"]) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py index b82a416d26df1..6e7ecff02fc60 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -31,7 +31,7 @@ ) from airflow.providers.common.compat.sdk import AirflowException -from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook, _serialize_value +from airflow.providers.google.cloud.hooks.vertex_ai.agent_engine import AgentEngineAsyncHook, serialize_value from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( @@ -130,101 +130,6 @@ def _serialize_job(self, job: Any) -> Any: return self.job_serializer_class.to_dict(job) -class AgentEngineDeleteTrigger(BaseTrigger): - """Trigger that waits until a Vertex AI Agent Engine delete operation completes.""" - - def __init__( - self, - location: str, - agent_engine_id: str, - operation_name: str, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - poll_interval: float = 30, - timeout: float | None = None, - ): - super().__init__() - self.location = location - self.agent_engine_id = agent_engine_id - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.poll_interval = poll_interval - self.timeout = timeout - 
self.operation_name = operation_name - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger", - { - "location": self.location, - "agent_engine_id": self.agent_engine_id, - "gcp_conn_id": self.gcp_conn_id, - "impersonation_chain": self.impersonation_chain, - "poll_interval": self.poll_interval, - "timeout": self.timeout, - "operation_name": self.operation_name, - }, - ) - - @cached_property - def async_hook(self) -> AgentEngineAsyncHook: - return AgentEngineAsyncHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - - async def run(self) -> AsyncIterator[TriggerEvent]: - start_time = time.monotonic() - try: - while True: - operation = await self.async_hook.get_agent_engine_operation( - location=self.location, - operation_name=self.operation_name, - ) - if operation.get("done"): - if operation.get("error"): - yield TriggerEvent( - { - "status": "error", - "message": str(operation["error"]), - "agent_engine_id": self.agent_engine_id, - } - ) - return - yield TriggerEvent( - { - "status": "success", - "message": "Agent Engine deleted", - "agent_engine_id": self.agent_engine_id, - } - ) - return - - if self.timeout is not None and time.monotonic() - start_time >= self.timeout: - yield TriggerEvent( - { - "status": "timeout", - "message": ( - f"Timed out waiting for Agent Engine {self.agent_engine_id} to be deleted" - ), - "agent_engine_id": self.agent_engine_id, - } - ) - return - - self.log.info("Waiting for Agent Engine %s to be deleted.", self.agent_engine_id) - await asyncio.sleep(self.poll_interval) - except Exception as err: - self.log.exception("Exception occurred while waiting for Agent Engine deletion.") - yield TriggerEvent( - { - "status": "error", - "message": f"Failed while polling Agent Engine deletion: {err}", - "agent_engine_id": self.agent_engine_id, - } - ) - - class AgentEngineQueryJobTrigger(BaseTrigger): """Trigger that waits 
until a Vertex AI Agent Engine query job completes.""" @@ -232,7 +137,7 @@ def __init__( self, project_id: str, location: str, - operation_name: str, + operation_id: str, config: vertexai_types.CheckQueryJobAgentEngineConfigOrDict | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -242,7 +147,7 @@ def __init__( super().__init__() self.project_id = project_id self.location = location - self.operation_name = operation_name + self.operation_id = operation_id self.config = config self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -255,8 +160,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]: { "project_id": self.project_id, "location": self.location, - "operation_name": self.operation_name, - "config": _serialize_value(self.config), + "operation_id": self.operation_id, + "config": serialize_value(self.config), "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, @@ -278,11 +183,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: query_job = await self.async_hook.check_query_agent_engine_job( project_id=self.project_id, location=self.location, - operation_name=self.operation_name, + operation_id=self.operation_id, config=self.config, ) status = getattr(query_job, "status", None) - serialized_query_job = _serialize_value(query_job) + serialized_query_job = serialize_value(query_job) if status == "SUCCESS": yield TriggerEvent( { @@ -296,7 +201,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent( { "status": "error", - "message": f"Agent Engine query job {self.operation_name} failed.", + "message": f"Agent Engine query job {self.operation_id} failed.", "query_job": serialized_query_job, } ) @@ -306,7 +211,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: { "status": "error", "message": ( - f"Agent Engine query job {self.operation_name} completed with " + f"Agent Engine 
query job {self.operation_id} completed with " f"unexpected status {status}." ), "query_job": serialized_query_job, @@ -318,13 +223,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]: yield TriggerEvent( { "status": "timeout", - "message": f"Timed out waiting for Agent Engine query job {self.operation_name}", + "message": f"Timed out waiting for Agent Engine query job {self.operation_id}", "query_job": serialized_query_job, } ) return - self.log.info("Waiting for Agent Engine query job %s to complete.", self.operation_name) + self.log.info("Waiting for Agent Engine query job %s to complete.", self.operation_id) await asyncio.sleep(self.poll_interval) except Exception as err: self.log.exception("Exception occurred while waiting for Agent Engine query job.") @@ -332,7 +237,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: { "status": "error", "message": f"Failed while polling Agent Engine query job: {err}", - "query_job": {"operation_name": self.operation_name}, + "query_job": {"operation_id": self.operation_id}, } ) diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index b59cfd6bd8e4d..4be41d802c998 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -22,6 +22,7 @@ from __future__ import annotations +import json import os from datetime import datetime from typing import TYPE_CHECKING @@ -31,11 +32,10 @@ from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( - CheckQueryAgentEngineOperator, CreateAgentEngineOperator, DeleteAgentEngineOperator, GetAgentEngineOperator, - QueryAgentEngineOperator, + RunQueryJobOperator, UpdateAgentEngineOperator, ) @@ 
-56,15 +56,27 @@ ) AGENT_ENGINE_ID = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'].split('/')[-1] }}" -QUERY_OPERATION_NAME = "{{ task_instance.xcom_pull(task_ids='query_agent_engine')['job_name'] }}" BUCKET_NAME = f"bucket-{DAG_ID}-{ENV_ID}".replace("_", "-") DISPLAY_NAME = f"airflow-agent-engine-{ENV_ID}" QUERY_CONFIG: vertexai_types.RunQueryJobAgentEngineConfigDict = { - "query": "Respond with a short acknowledgement.", + "query": json.dumps( + { + "dag_id": "vertex_ai_agent_engine_operations", + "run_id": "manual__agent_engine_system_test", + "dag_file": "providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py", + "failed_task": { + "task_id": "transform", + "state": "failed", + "try_number": 1, + }, + "log_excerpt": "KeyError: 'rowz'", + } + ), "output_gcs_uri": f"gs://{BUCKET_NAME}/query-output/", } + with DAG( DAG_ID, schedule="@once", @@ -85,6 +97,10 @@ "max_instances": 1, "resource_limits": {"cpu": "1", "memory": "1Gi"}, "container_spec": {"image_uri": CONTAINER_URI}, + "env_vars": { + "AGENT_USE_MODEL": "false", + "AGENT_USE_MOCKS": "true", + }, "class_methods": [ { "name": "query", @@ -110,26 +126,17 @@ project_id=PROJECT_ID, ) - # [START how_to_cloud_vertex_ai_query_agent_engine_operator] - query_agent_engine = QueryAgentEngineOperator( - task_id="query_agent_engine", + # [START how_to_cloud_vertex_ai_run_query_job_operator] + run_query_job = RunQueryJobOperator( + task_id="run_query_job", project_id=PROJECT_ID, location=LOCATION, agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, - ) - # [END how_to_cloud_vertex_ai_query_agent_engine_operator] - - # [START how_to_cloud_vertex_ai_check_query_agent_engine_operator] - check_query_agent_engine = CheckQueryAgentEngineOperator( - task_id="check_query_agent_engine", - project_id=PROJECT_ID, - location=LOCATION, - operation_name=QUERY_OPERATION_NAME, - config={"retrieve_result": True}, + check_config={"retrieve_result": True}, deferrable=True, ) - # [END 
how_to_cloud_vertex_ai_check_query_agent_engine_operator] + # [END how_to_cloud_vertex_ai_run_query_job_operator] # [START how_to_cloud_vertex_ai_update_agent_engine_operator] update_agent_engine = UpdateAgentEngineOperator( @@ -151,7 +158,6 @@ location=LOCATION, agent_engine_id=AGENT_ENGINE_ID, force=True, - deferrable=True, trigger_rule=TriggerRule.ALL_DONE, ) # [END how_to_cloud_vertex_ai_delete_agent_engine_operator] @@ -167,8 +173,7 @@ create_agent_engine >> get_agent_engine >> create_bucket - >> query_agent_engine - >> check_query_agent_engine + >> run_query_job >> update_agent_engine >> delete_agent_engine >> delete_bucket diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index 87c9984f03c3a..520d37b73c656 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -35,6 +35,8 @@ AGENT_ENGINE_NAME = "projects/test-project/locations/us-central1/reasoningEngines/123" OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" QUERY_OPERATION_NAME = "projects/test-project/locations/us-central1/operations/query-123" +OPERATION_ID = "delete-123" +QUERY_OPERATION_ID = "query-123" CONFIG = {"display_name": "test-agent-engine"} QUERY_CONFIG = {"query": "hello", "output_gcs_uri": "gs://test-bucket/query-output/"} CHECK_QUERY_CONFIG = {"retrieve_result": True} @@ -70,6 +72,7 @@ def test_create_agent_engine(self, mock_get_client): mock_get_client.assert_called_once_with(self.hook, project_id=GCP_PROJECT, location=GCP_LOCATION) mock_get_client.return_value.create.assert_called_once_with( + agent_engine=None, agent=None, config=CONFIG, ) @@ -83,12 +86,24 @@ def test_get_agent_engine(self, mock_get_client): agent_engine_id=AGENT_ENGINE_ID, ) - 
mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME) + mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME, config=None) assert result == mock_get_client.return_value.get.return_value @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) - def test_query_agent_engine(self, mock_get_client): - result = self.hook.query_agent_engine( + def test_get_agent_engine_with_config(self, mock_get_client): + result = self.hook.get_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=CONFIG, + ) + + mock_get_client.return_value.get.assert_called_once_with(name=AGENT_ENGINE_NAME, config=CONFIG) + assert result == mock_get_client.return_value.get.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_run_query_job(self, mock_get_client): + result = self.hook.run_query_job( project_id=GCP_PROJECT, location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, @@ -106,7 +121,7 @@ def test_check_query_agent_engine_job(self, mock_get_client): result = self.hook.check_query_agent_engine_job( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) @@ -123,7 +138,7 @@ def test_wait_for_query_agent_engine_job_returns_when_successful(self, mock_chec result = self.hook.wait_for_query_agent_engine_job( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) @@ -131,7 +146,7 @@ def test_wait_for_query_agent_engine_job_returns_when_successful(self, mock_chec self.hook, project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) assert result == mock_check_query_job.return_value @@ -146,7 +161,7 @@ def 
test_wait_for_query_agent_engine_job_polls_until_success(self, mock_check_qu result = self.hook.wait_for_query_agent_engine_job( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, poll_interval=10, ) @@ -164,7 +179,7 @@ def test_wait_for_query_agent_engine_job_raises_on_failed_status(self, mock_chec self.hook.wait_for_query_agent_engine_job( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) @@ -182,7 +197,7 @@ def test_wait_for_query_agent_engine_job_raises_on_unexpected_status( self.hook.wait_for_query_agent_engine_job( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) @@ -201,7 +216,7 @@ def test_wait_for_query_agent_engine_job_times_out( self.hook.wait_for_query_agent_engine_job( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, timeout=1, ) @@ -220,6 +235,25 @@ def test_update_agent_engine(self, mock_get_client): mock_get_client.return_value.update.assert_called_once_with( name=AGENT_ENGINE_NAME, agent=None, + agent_engine=None, + config=CONFIG, + ) + assert result == mock_get_client.return_value.update.return_value + + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) + def test_update_agent_engine_with_agent_engine_alias(self, mock_get_client): + result = self.hook.update_agent_engine( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + agent_engine=mock.sentinel.agent_engine, + config=CONFIG, + ) + + mock_get_client.return_value.update.assert_called_once_with( + name=AGENT_ENGINE_NAME, + agent=None, + agent_engine=mock.sentinel.agent_engine, config=CONFIG, ) assert result == 
mock_get_client.return_value.update.return_value @@ -247,30 +281,51 @@ def test_get_agent_engine_operation(self, mock_session): mock_session.return_value.get.return_value.json.return_value = {"name": OPERATION_NAME, "done": True} result = self.hook.get_agent_engine_operation( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, ) mock_session.assert_called_once_with(mock.sentinel.credentials) mock_session.return_value.get.assert_called_once_with( - f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}" + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}", + timeout=60.0, ) mock_session.return_value.get.return_value.raise_for_status.assert_called_once_with() assert result == {"name": OPERATION_NAME, "done": True} + @mock.patch(AGENT_ENGINE_STRING.format("google.auth.transport.requests.AuthorizedSession"), autospec=True) + def test_get_agent_engine_operation_with_request_timeout(self, mock_session): + self.hook.get_credentials = mock.Mock(return_value=mock.sentinel.credentials, spec=()) + mock_session.return_value.get.return_value.json.return_value = {"name": OPERATION_NAME, "done": True} + + self.hook.get_agent_engine_operation( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + operation_id=OPERATION_ID, + request_timeout=10, + ) + + mock_session.return_value.get.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}", + timeout=10, + ) + @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_operation"), autospec=True) def test_wait_for_agent_engine_operation_returns_when_done(self, mock_get_operation): mock_get_operation.return_value = {"name": OPERATION_NAME, "done": True} self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, ) mock_get_operation.assert_called_once_with( self.hook, + 
project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, ) @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) @@ -281,8 +336,9 @@ def test_wait_for_agent_engine_operation_polls_until_done(self, mock_get_operati mock_get_operation.side_effect = [running_operation, running_operation, done_operation] self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, poll_interval=10, ) @@ -296,8 +352,9 @@ def test_wait_for_agent_engine_operation_raises_on_error(self, mock_get_operatio with pytest.raises(RuntimeError, match="Agent Engine operation .* failed"): self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, ) @mock.patch(AGENT_ENGINE_STRING.format("time.sleep"), autospec=True) @@ -309,8 +366,9 @@ def test_wait_for_agent_engine_operation_times_out(self, mock_get_operation, moc with pytest.raises(TimeoutError, match="Timed out waiting for Agent Engine operation"): self.hook.wait_for_agent_engine_operation( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, timeout=1, ) @@ -332,13 +390,16 @@ async def test_get_agent_engine_operation_calls_sync_hook(self): self.hook.get_sync_hook = mock.AsyncMock(return_value=sync_hook) result = await self.hook.get_agent_engine_operation( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, ) sync_hook.get_agent_engine_operation.assert_called_once_with( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION_NAME, + operation_id=OPERATION_ID, + request_timeout=60.0, ) assert result == {"name": OPERATION_NAME, "done": True} @@ -351,14 +412,14 @@ async def test_check_query_agent_engine_job_calls_sync_hook(self): result = await 
self.hook.check_query_agent_engine_job( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) sync_hook.check_query_agent_engine_job.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) assert result == mock.sentinel.query_job diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index c852d685c6877..daf16ac46c7ac 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -23,11 +23,10 @@ from airflow.providers.common.compat.sdk import TaskDeferred from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( - CheckQueryAgentEngineOperator, CreateAgentEngineOperator, DeleteAgentEngineOperator, GetAgentEngineOperator, - QueryAgentEngineOperator, + RunQueryJobOperator, UpdateAgentEngineOperator, ) @@ -45,6 +44,8 @@ CHECK_QUERY_CONFIG = {"retrieve_result": True} OPERATION = {"name": "operations/delete-123", "done": False} QUERY_OPERATION_NAME = "operations/query-123" +OPERATION_ID = "delete-123" +QUERY_OPERATION_ID = "query-123" class FakeModel: @@ -96,10 +97,36 @@ def test_execute(self, mock_hook, context): project_id=GCP_PROJECT, location=GCP_LOCATION, agent=None, + agent_engine=None, config=CONFIG, ) assert result == {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_with_agent_engine_alias(self, mock_hook, context): + mock_hook.return_value.create_agent_engine.return_value = FakeAgentEngine( + {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} + ) + op = 
CreateAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine=mock.sentinel.agent_engine, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + op.execute(context=context) + + mock_hook.return_value.create_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent=None, + agent_engine=mock.sentinel.agent_engine, + config=CONFIG, + ) + class TestGetAgentEngineOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) @@ -120,56 +147,56 @@ def test_execute(self, mock_hook, context): project_id=GCP_PROJECT, location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, + config=None, ) assert result == {"name": AGENT_ENGINE_NAME} - -class TestQueryAgentEngineOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) - def test_execute(self, mock_hook, context): - result_payload = { - "job_name": "operations/query-123", - "input_gcs_uri": "gs://test-bucket/query-output/input.json", - "output_gcs_uri": "gs://test-bucket/query-output/output.json", - } - mock_hook.return_value.query_agent_engine.return_value = FakeModel(result_payload) - op = QueryAgentEngineOperator( + def test_execute_with_config(self, mock_hook, context): + mock_hook.return_value.get_agent_engine.return_value = FakeAgentEngine({"name": AGENT_ENGINE_NAME}) + op = GetAgentEngineOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, - config=QUERY_CONFIG, + config=CONFIG, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - result = op.execute(context=context) + op.execute(context=context) - mock_hook.return_value.query_agent_engine.assert_called_once_with( + mock_hook.return_value.get_agent_engine.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, - config=QUERY_CONFIG, + config=CONFIG, ) - assert result == 
result_payload -class TestCheckQueryAgentEngineOperator: +class TestRunQueryJobOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) def test_execute(self, mock_hook, context): - result_payload = { + run_result_payload = { + "job_name": "operations/query-123", + "input_gcs_uri": "gs://test-bucket/query-output/input.json", + "output_gcs_uri": "gs://test-bucket/query-output/output.json", + } + query_result_payload = { "operation_name": QUERY_OPERATION_NAME, "output_gcs_uri": "gs://test-bucket/query-output/output.json", "status": "SUCCESS", "result": "done", } - mock_hook.return_value.wait_for_query_agent_engine_job.return_value = FakeModel(result_payload) - op = CheckQueryAgentEngineOperator( + mock_hook.return_value.run_query_job.return_value = FakeModel(run_result_payload) + mock_hook.return_value.wait_for_query_agent_engine_job.return_value = FakeModel(query_result_payload) + op = RunQueryJobOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, - config=CHECK_QUERY_CONFIG, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + check_config=CHECK_QUERY_CONFIG, poll_interval=1, timeout=60, gcp_conn_id=GCP_CONN_ID, @@ -178,25 +205,57 @@ def test_execute(self, mock_hook, context): result = op.execute(context=context) + mock_hook.return_value.run_query_job.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + ) mock_hook.return_value.wait_for_query_agent_engine_job.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, poll_interval=1, timeout=60, ) + assert result == query_result_payload + + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_without_wait(self, mock_hook, context): + result_payload = { + "job_name": "operations/query-123", + 
"input_gcs_uri": "gs://test-bucket/query-output/input.json", + "output_gcs_uri": "gs://test-bucket/query-output/output.json", + } + mock_hook.return_value.run_query_job.return_value = FakeModel(result_payload) + op = RunQueryJobOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + config=CHECK_QUERY_CONFIG, + agent_engine_id=AGENT_ENGINE_ID, + wait_for_completion=False, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + result = op.execute(context=context) + + mock_hook.return_value.wait_for_query_agent_engine_job.assert_not_called() assert result == result_payload @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineQueryJobTrigger"), autospec=True) @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) def test_execute_deferrable(self, mock_hook, mock_trigger, context): - op = CheckQueryAgentEngineOperator( + mock_hook.return_value.run_query_job.return_value = FakeModel({"job_name": QUERY_OPERATION_NAME}) + op = RunQueryJobOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + agent_engine_id=AGENT_ENGINE_ID, config=CHECK_QUERY_CONFIG, + check_config=CHECK_QUERY_CONFIG, poll_interval=1, timeout=60, gcp_conn_id=GCP_CONN_ID, @@ -211,7 +270,7 @@ def test_execute_deferrable(self, mock_hook, mock_trigger, context): mock_trigger.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -219,12 +278,27 @@ def test_execute_deferrable(self, mock_hook, mock_trigger, context): timeout=60, ) + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_raises_when_query_job_has_no_name(self, mock_hook, context): + mock_hook.return_value.run_query_job.return_value = FakeModel({}) + op = RunQueryJobOperator( + task_id=TASK_ID, + 
project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + ) + + with pytest.raises(RuntimeError, match="Agent Engine query job did not include an operation name."): + op.execute(context=context) + + mock_hook.return_value.wait_for_query_agent_engine_job.assert_not_called() + def test_execute_complete_success(self, context): - op = CheckQueryAgentEngineOperator( + op = RunQueryJobOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) query_job = {"operation_name": QUERY_OPERATION_NAME, "status": "SUCCESS"} @@ -236,33 +310,33 @@ def test_execute_complete_success(self, context): assert result == query_job def test_execute_complete_error(self, context): - op = CheckQueryAgentEngineOperator( + op = RunQueryJobOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) with pytest.raises(RuntimeError, match="boom"): op.execute_complete(context=context, event={"status": "error", "message": "boom"}) def test_execute_complete_timeout(self, context): - op = CheckQueryAgentEngineOperator( + op = RunQueryJobOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) with pytest.raises(TimeoutError, match="timed out"): op.execute_complete(context=context, event={"status": "timeout", "message": "timed out"}) def test_execute_complete_without_event(self, context): - op = CheckQueryAgentEngineOperator( + op = RunQueryJobOperator( task_id=TASK_ID, project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + agent_engine_id=AGENT_ENGINE_ID, ) with pytest.raises(RuntimeError, match="No event received in trigger callback"): @@ -292,10 +366,38 @@ def test_execute(self, mock_hook, context): location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, agent=None, + 
agent_engine=None, config=CONFIG, ) assert result == {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) + def test_execute_with_agent_engine_alias(self, mock_hook, context): + mock_hook.return_value.update_agent_engine.return_value = FakeAgentEngine( + {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} + ) + op = UpdateAgentEngineOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + agent_engine=mock.sentinel.agent_engine, + config=CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + op.execute(context=context) + + mock_hook.return_value.update_agent_engine.assert_called_once_with( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + agent=None, + agent_engine=mock.sentinel.agent_engine, + config=CONFIG, + ) + class TestDeleteAgentEngineOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) @@ -343,8 +445,9 @@ def test_execute_waits_until_deleted(self, mock_hook, context): result = op.execute(context=context) mock_hook.return_value.wait_for_agent_engine_operation.assert_called_once_with( + project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=OPERATION["name"], + operation_id=OPERATION_ID, poll_interval=1, timeout=60, ) @@ -388,90 +491,3 @@ def test_execute_raises_when_delete_operation_has_no_name(self, mock_hook, conte op.execute(context=context) mock_hook.return_value.wait_for_agent_engine_operation.assert_not_called() - - @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineDeleteTrigger"), autospec=True) - @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) - def test_execute_deferrable(self, mock_hook, mock_trigger, context): - mock_hook.return_value.delete_agent_engine.return_value = FakeModel(OPERATION) - op = DeleteAgentEngineOperator( - task_id=TASK_ID, - 
project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - wait_for_completion=True, - poll_interval=1, - timeout=60, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - deferrable=True, - ) - - with pytest.raises(TaskDeferred) as exc: - op.execute(context=context) - - mock_trigger.assert_called_once_with( - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - poll_interval=1, - timeout=60, - operation_name=OPERATION["name"], - ) - assert exc.value.kwargs == {"operation": OPERATION} - - def test_execute_complete_success(self, context): - op = DeleteAgentEngineOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - result = op.execute_complete( - context=context, - event={ - "status": "success", - "message": "Agent Engine deleted", - "agent_engine_id": AGENT_ENGINE_ID, - }, - operation=OPERATION, - ) - - assert result == OPERATION - - def test_execute_complete_error(self, context): - op = DeleteAgentEngineOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - with pytest.raises(RuntimeError, match="boom"): - op.execute_complete(context=context, event={"status": "error", "message": "boom"}) - - def test_execute_complete_timeout(self, context): - op = DeleteAgentEngineOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - with pytest.raises(TimeoutError, match="timed out"): - op.execute_complete( - context=context, - event={"status": "timeout", "message": "timed out", "agent_engine_id": AGENT_ENGINE_ID}, - ) - - def test_execute_complete_without_event(self, context): - op = DeleteAgentEngineOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - ) - - with pytest.raises(RuntimeError, 
match="No event received in trigger callback"): - op.execute_complete(context=context) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py index f823b35cfd9c3..dfb2b418be5d9 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_vertex_ai_agent_engine.py @@ -21,10 +21,7 @@ import pytest -from airflow.providers.google.cloud.triggers.vertex_ai import ( - AgentEngineDeleteTrigger, - AgentEngineQueryJobTrigger, -) +from airflow.providers.google.cloud.triggers.vertex_ai import AgentEngineQueryJobTrigger from airflow.triggers.base import TriggerEvent GCP_PROJECT = "test-project" @@ -32,8 +29,8 @@ GCP_CONN_ID = "test-conn" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] AGENT_ENGINE_ID = "123" -OPERATION_NAME = "projects/test-project/locations/us-central1/operations/delete-123" QUERY_OPERATION_NAME = "projects/test-project/locations/us-central1/operations/query-123" +QUERY_OPERATION_ID = "query-123" CHECK_QUERY_CONFIG = {"retrieve_result": True} @@ -47,25 +44,12 @@ def model_dump(self, mode="json"): return self.payload -@pytest.fixture -def delete_trigger(): - return AgentEngineDeleteTrigger( - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - poll_interval=1, - timeout=60, - operation_name=OPERATION_NAME, - ) - - @pytest.fixture def query_job_trigger(): return AgentEngineQueryJobTrigger( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -74,115 +58,6 @@ def query_job_trigger(): ) -class TestAgentEngineDeleteTrigger: - def test_serialize(self, delete_trigger): - assert delete_trigger.serialize() 
== ( - "airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineDeleteTrigger", - { - "location": GCP_LOCATION, - "agent_engine_id": AGENT_ENGINE_ID, - "gcp_conn_id": GCP_CONN_ID, - "impersonation_chain": IMPERSONATION_CHAIN, - "poll_interval": 1, - "timeout": 60, - "operation_name": OPERATION_NAME, - }, - ) - - @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) - async def test_run_loop_return_success_event(self, mock_hook, delete_trigger): - mock_hook.return_value.get_agent_engine_operation.return_value = {"done": True} - - event = await delete_trigger.run().asend(None) - - mock_hook.return_value.get_agent_engine_operation.assert_called_once_with( - location=GCP_LOCATION, - operation_name=OPERATION_NAME, - ) - assert event == TriggerEvent( - { - "status": "success", - "message": "Agent Engine deleted", - "agent_engine_id": AGENT_ENGINE_ID, - } - ) - - @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) - async def test_run_loop_polls_until_success(self, mock_hook, mock_sleep, delete_trigger): - mock_hook.return_value.get_agent_engine_operation.side_effect = [ - {"done": False}, - {"done": False}, - {"done": True}, - ] - - event = await delete_trigger.run().asend(None) - - assert mock_hook.return_value.get_agent_engine_operation.call_count == 3 - assert mock_sleep.call_count == 2 - mock_sleep.assert_awaited_with(1) - assert event == TriggerEvent( - { - "status": "success", - "message": "Agent Engine deleted", - "agent_engine_id": AGENT_ENGINE_ID, - } - ) - - @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.asyncio.sleep", autospec=True) - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) - async def 
test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, delete_trigger): - delete_trigger.timeout = -1 - mock_hook.return_value.get_agent_engine_operation.return_value = {"done": False} - - event = await delete_trigger.run().asend(None) - - mock_sleep.assert_not_called() - assert event == TriggerEvent( - { - "status": "timeout", - "message": f"Timed out waiting for Agent Engine {AGENT_ENGINE_ID} to be deleted", - "agent_engine_id": AGENT_ENGINE_ID, - } - ) - - @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) - async def test_run_loop_return_error_event(self, mock_hook, delete_trigger): - mock_hook.return_value.get_agent_engine_operation.side_effect = RuntimeError("boom") - - event = await delete_trigger.run().asend(None) - - assert event == TriggerEvent( - { - "status": "error", - "message": "Failed while polling Agent Engine deletion: boom", - "agent_engine_id": AGENT_ENGINE_ID, - } - ) - - @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.vertex_ai.AgentEngineAsyncHook", autospec=True) - async def test_run_loop_return_operation_error_event(self, mock_hook, delete_trigger): - mock_hook.return_value.get_agent_engine_operation.return_value = { - "done": True, - "error": {"message": "boom"}, - } - - event = await delete_trigger.run().asend(None) - - assert event == TriggerEvent( - { - "status": "error", - "message": "{'message': 'boom'}", - "agent_engine_id": AGENT_ENGINE_ID, - } - ) - - class TestAgentEngineQueryJobTrigger: def test_serialize(self, query_job_trigger): assert query_job_trigger.serialize() == ( @@ -190,7 +65,7 @@ def test_serialize(self, query_job_trigger): { "project_id": GCP_PROJECT, "location": GCP_LOCATION, - "operation_name": QUERY_OPERATION_NAME, + "operation_id": QUERY_OPERATION_ID, "config": CHECK_QUERY_CONFIG, "gcp_conn_id": GCP_CONN_ID, "impersonation_chain": IMPERSONATION_CHAIN, @@ -204,7 +79,7 @@ def 
test_serialize_with_pydantic_config(self): trigger = AgentEngineQueryJobTrigger( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=pydantic_config, gcp_conn_id=GCP_CONN_ID, poll_interval=1, @@ -228,7 +103,7 @@ async def test_run_loop_return_success_event(self, mock_hook, query_job_trigger) mock_hook.return_value.check_query_agent_engine_job.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - operation_name=QUERY_OPERATION_NAME, + operation_id=QUERY_OPERATION_ID, config=CHECK_QUERY_CONFIG, ) assert event == TriggerEvent( @@ -275,7 +150,7 @@ async def test_run_loop_return_failed_event(self, mock_hook, query_job_trigger): assert event == TriggerEvent( { "status": "error", - "message": f"Agent Engine query job {QUERY_OPERATION_NAME} failed.", + "message": f"Agent Engine query job {QUERY_OPERATION_ID} failed.", "query_job": query_job, } ) @@ -293,7 +168,7 @@ async def test_run_loop_return_timeout_event(self, mock_hook, mock_sleep, query_ assert event == TriggerEvent( { "status": "timeout", - "message": f"Timed out waiting for Agent Engine query job {QUERY_OPERATION_NAME}", + "message": f"Timed out waiting for Agent Engine query job {QUERY_OPERATION_ID}", "query_job": {"status": "RUNNING"}, } ) @@ -310,8 +185,7 @@ async def test_run_loop_return_error_event_for_unexpected_status(self, mock_hook { "status": "error", "message": ( - f"Agent Engine query job {QUERY_OPERATION_NAME} completed with " - "unexpected status CANCELLED." + f"Agent Engine query job {QUERY_OPERATION_ID} completed with unexpected status CANCELLED." 
), "query_job": query_job, } @@ -328,6 +202,6 @@ async def test_run_loop_return_error_event(self, mock_hook, query_job_trigger): { "status": "error", "message": "Failed while polling Agent Engine query job: boom", - "query_job": {"operation_name": QUERY_OPERATION_NAME}, + "query_job": {"operation_id": QUERY_OPERATION_ID}, } ) From 91df057f61c35baa71132e0bbbacfe61e3095e4d Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Fri, 26 Jun 2026 17:34:24 -0300 Subject: [PATCH 26/32] Fix mypy type error in RunQueryJobOperator wait assignment --- .../google/cloud/operators/vertex_ai/agent_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index f576c3f06905c..6eb4643b17317 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -266,7 +266,7 @@ def execute(self, context: Context) -> dict[str, Any]: method_name="execute_complete", ) - query_job = self.hook.wait_for_query_agent_engine_job( + completed_job = self.hook.wait_for_query_agent_engine_job( project_id=self.project_id, location=self.location, operation_id=operation_id, @@ -275,7 +275,7 @@ def execute(self, context: Context) -> dict[str, Any]: timeout=self.timeout, ) self.log.info("Agent Engine query job %s completed.", operation_name) - return serialize_value(query_job) + return serialize_value(completed_job) def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]: if event is None: From cccc3cddbb254ffccb1283c07b5fa065e1b71b70 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Sat, 27 Jun 2026 14:54:11 -0300 Subject: 
[PATCH 27/32] Simplify Vertex AI Agent Engine system test setup --- .../example_vertex_ai_agent_engine.py | 56 +++++-------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index 4be41d802c998..9a6dbf4bef805 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -30,7 +30,6 @@ if TYPE_CHECKING: from vertexai._genai import types as vertexai_types -from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.operators.vertex_ai.agent_engine import ( CreateAgentEngineOperator, DeleteAgentEngineOperator, @@ -54,26 +53,21 @@ "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_URI", "us-central1-docker.pkg.dev/example-project/example-repository/example-agent:latest", ) +CONTAINER_ENV_VARS: dict[str, str] = json.loads( + os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_ENV_VARS", "{}") +) +QUERY_STR = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY", "test query") +QUERY_OUTPUT_GCS_URI = os.environ.get( + "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY_OUTPUT_GCS_URI", + f"gs://{PROJECT_ID}-agent-engine-query-output/query-output/", +) AGENT_ENGINE_ID = "{{ task_instance.xcom_pull(task_ids='create_agent_engine')['name'].split('/')[-1] }}" -BUCKET_NAME = f"bucket-{DAG_ID}-{ENV_ID}".replace("_", "-") DISPLAY_NAME = f"airflow-agent-engine-{ENV_ID}" QUERY_CONFIG: vertexai_types.RunQueryJobAgentEngineConfigDict = { - "query": json.dumps( - { - "dag_id": "vertex_ai_agent_engine_operations", - "run_id": "manual__agent_engine_system_test", - "dag_file": "providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py", - 
"failed_task": { - "task_id": "transform", - "state": "failed", - "try_number": 1, - }, - "log_excerpt": "KeyError: 'rowz'", - } - ), - "output_gcs_uri": f"gs://{BUCKET_NAME}/query-output/", + "query": QUERY_STR, + "output_gcs_uri": QUERY_OUTPUT_GCS_URI, } @@ -97,10 +91,7 @@ "max_instances": 1, "resource_limits": {"cpu": "1", "memory": "1Gi"}, "container_spec": {"image_uri": CONTAINER_URI}, - "env_vars": { - "AGENT_USE_MODEL": "false", - "AGENT_USE_MOCKS": "true", - }, + **({"env_vars": CONTAINER_ENV_VARS} if CONTAINER_ENV_VARS else {}), "class_methods": [ { "name": "query", @@ -120,12 +111,6 @@ ) # [END how_to_cloud_vertex_ai_get_agent_engine_operator] - create_bucket = GCSCreateBucketOperator( - task_id="create_bucket", - bucket_name=BUCKET_NAME, - project_id=PROJECT_ID, - ) - # [START how_to_cloud_vertex_ai_run_query_job_operator] run_query_job = RunQueryJobOperator( task_id="run_query_job", @@ -134,6 +119,8 @@ agent_engine_id=AGENT_ENGINE_ID, config=QUERY_CONFIG, check_config={"retrieve_result": True}, + poll_interval=10, + timeout=900, deferrable=True, ) # [END how_to_cloud_vertex_ai_run_query_job_operator] @@ -162,22 +149,7 @@ ) # [END how_to_cloud_vertex_ai_delete_agent_engine_operator] - delete_bucket = GCSDeleteBucketOperator( - task_id="delete_bucket", - bucket_name=BUCKET_NAME, - force=True, - trigger_rule=TriggerRule.ALL_DONE, - ) - - ( - create_agent_engine - >> get_agent_engine - >> create_bucket - >> run_query_job - >> update_agent_engine - >> delete_agent_engine - >> delete_bucket - ) + (create_agent_engine >> get_agent_engine >> run_query_job >> update_agent_engine >> delete_agent_engine) # ### Everything below this line is not part of example ### # ### Just for system tests purpose ### From 140ee309773c28c89ce2f762ce5256bf5d3d8195 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:30:38 -0300 Subject: [PATCH 28/32] Update Vertex AI Agent Engine system test --- 
.../google/docs/operators/cloud/vertex_ai.rst | 9 ++- .../cloud/hooks/vertex_ai/agent_engine.py | 8 +- .../cloud/operators/vertex_ai/agent_engine.py | 10 --- .../example_vertex_ai_agent_engine.py | 75 ++++++++++++++++--- .../hooks/vertex_ai/test_agent_engine.py | 20 ----- .../operators/vertex_ai/test_agent_engine.py | 54 ------------- 6 files changed, 76 insertions(+), 100 deletions(-) diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index 96710185a922f..f1ba0e652d315 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -54,7 +54,6 @@ can include ``query`` and ``output_gcs_uri``. The SDK writes query input and out through Google Cloud Storage. By default, the operator waits for the query job to complete and returns the serialized query job result. Set ``retrieve_result`` to ``True`` in ``check_config`` to return the query job result from Google Cloud Storage. -The same operation can be performed in the deferrable mode. .. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py :language: python @@ -62,6 +61,14 @@ The same operation can be performed in the deferrable mode. :start-after: [START how_to_cloud_vertex_ai_run_query_job_operator] :end-before: [END how_to_cloud_vertex_ai_run_query_job_operator] +The same operation can be performed in the deferrable mode. + +.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + :end-before: [END how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + To update an Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.UpdateAgentEngineOperator`. 
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index a0e016d192729..591b4ef30b1f2 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -104,7 +104,6 @@ def build_operation_name(project_id: str, location: str, operation_id: str) -> s def create_agent_engine( self, location: str, - agent_engine: Any | None = None, agent: Any | None = None, config: types.AgentEngineConfigOrDict | None = None, project_id: str = PROVIDE_PROJECT_ID, @@ -113,14 +112,13 @@ def create_agent_engine( Create an Agent Engine. :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param agent_engine: Optional. Deprecated alias for ``agent``. :param agent: Optional. The agent object to deploy. :param config: Optional. Configuration for the Agent Engine. :param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) - return client.create(agent_engine=agent_engine, agent=agent, config=config) + return client.create(agent=agent, config=config) @GoogleBaseHook.fallback_to_default_project_id def get_agent_engine( @@ -236,7 +234,6 @@ def update_agent_engine( agent_engine_id: str, config: types.AgentEngineConfigOrDict, agent: Any | None = None, - agent_engine: Any | None = None, project_id: str = PROVIDE_PROJECT_ID, ) -> types.AgentEngine: """ @@ -246,13 +243,12 @@ def update_agent_engine( :param agent_engine_id: Required. The Agent Engine ID. :param config: Required. Configuration for the Agent Engine update. :param agent: Optional. The updated agent object to deploy. - :param agent_engine: Optional. Deprecated alias for ``agent``. 
:param project_id: Optional. The ID of the Google Cloud project. Defaults to the project configured in the connection. """ client = self.get_agent_engine_client(project_id=project_id, location=location) name = self.build_agent_engine_name(project_id, location, agent_engine_id) - return client.update(name=name, agent=agent, agent_engine=agent_engine, config=config) + return client.update(name=name, agent=agent, config=config) @GoogleBaseHook.fallback_to_default_project_id def delete_agent_engine( diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py index 6eb4643b17317..63382be83e042 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/agent_engine.py @@ -52,7 +52,6 @@ class CreateAgentEngineOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent: Optional. The agent object to deploy. - :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Optional. Configuration for the Agent Engine. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
@@ -62,7 +61,6 @@ class CreateAgentEngineOperator(GoogleCloudBaseOperator): "project_id", "location", "agent", - "agent_engine", "config", "gcp_conn_id", "impersonation_chain", @@ -74,7 +72,6 @@ def __init__( project_id: str, location: str, agent: Any | None = None, - agent_engine: Any | None = None, config: types.AgentEngineConfigOrDict | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, @@ -84,7 +81,6 @@ def __init__( self.project_id = project_id self.location = location self.agent = agent - self.agent_engine = agent_engine self.config = config self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -102,7 +98,6 @@ def execute(self, context: Context) -> dict[str, Any]: project_id=self.project_id, location=self.location, agent=self.agent, - agent_engine=self.agent_engine, config=self.config, ) result = _serialize_agent_engine(agent_engine) @@ -296,7 +291,6 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): :param location: Required. The ID of the Google Cloud location that the service belongs to. :param agent_engine_id: Required. The Agent Engine ID. :param agent: Optional. The updated agent object to deploy. - :param agent_engine: Optional. Deprecated alias for ``agent``. :param config: Required. Configuration for the Agent Engine update. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials. 
@@ -307,7 +301,6 @@ class UpdateAgentEngineOperator(GoogleCloudBaseOperator): "location", "agent_engine_id", "agent", - "agent_engine", "config", "gcp_conn_id", "impersonation_chain", @@ -321,7 +314,6 @@ def __init__( agent_engine_id: str, config: types.AgentEngineConfigOrDict, agent: Any | None = None, - agent_engine: Any | None = None, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, **kwargs, @@ -331,7 +323,6 @@ def __init__( self.location = location self.agent_engine_id = agent_engine_id self.agent = agent - self.agent_engine = agent_engine self.config = config self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain @@ -350,7 +341,6 @@ def execute(self, context: Context) -> dict[str, Any]: location=self.location, agent_engine_id=self.agent_engine_id, agent=self.agent, - agent_engine=self.agent_engine, config=self.config, ) result = _serialize_agent_engine(agent_engine) diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index 9a6dbf4bef805..9d856731ad9a1 100644 --- a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -46,17 +46,54 @@ from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] DAG_ID = "vertex_ai_agent_engine_operations" -LOCATION = "us-central1" ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") -CONTAINER_URI = os.environ.get( - "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_URI", + + +def _get_env(name: str, default: str = "") -> str: + return os.environ.get(name) or os.environ.get(f"AIRFLOW_VAR_{name}", default) + + +def _get_json_env(name: str, default: dict[str, str]) -> dict[str, str]: + value = 
os.environ.get(name) + return json.loads(value) if value else default + + +LOCATION = _get_env("GCP_REGION", "us-central1") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or _get_env("GCP_PROJECT_ID", "default") +CONTAINER_URI = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_URI") or _get_env( + "GCP_AGENT_ENGINE_CONTAINER_URI", "us-central1-docker.pkg.dev/example-project/example-repository/example-agent:latest", ) -CONTAINER_ENV_VARS: dict[str, str] = json.loads( - os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_ENV_VARS", "{}") +CONTAINER_ENV_VARS: dict[str, str] = _get_json_env( + "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_ENV_VARS", + {}, +) +if not CONTAINER_ENV_VARS: + CONTAINER_ENV_VARS = { + "GCP_PROJECT": PROJECT_ID, + "GCP_REGION": LOCATION, + "GEMINI_MODEL_ID": _get_env("GEMINI_MODEL_ID", "gemini-2.5-pro"), + "GITHUB_REPO": _get_env("GITHUB_REPO"), + "GITHUB_REF": _get_env("GITHUB_REF", "main"), + "GITHUB_DAG_PATH": _get_env("GITHUB_DAG_PATH", "airflow/dags"), + "GITHUB_TOKEN": _get_env("GITHUB_TOKEN"), + "SLACK_WEBHOOK_URL": _get_env("SLACK_WEBHOOK_URL"), + "AGENT_USE_MODEL": _get_env("AGENT_USE_MODEL", "true"), + "AGENT_USE_MOCKS": _get_env("AGENT_USE_MOCKS", "false"), + } +QUERY_STR = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY") or json.dumps( + { + "dag_id": "gcp_agentengine_demo_failing_etl", + "run_id": "manual__agentengine_smoke", + "dag_file": "gcp_gemini_agent_platform/demo_failing_etl.py", + "failed_task": { + "task_id": "transform", + "state": "failed", + "try_number": 1, + }, + "log_excerpt": "KeyError: 'rowz'", + } ) -QUERY_STR = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY", "test query") QUERY_OUTPUT_GCS_URI = os.environ.get( "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY_OUTPUT_GCS_URI", f"gs://{PROJECT_ID}-agent-engine-query-output/query-output/", @@ -121,10 +158,23 @@ check_config={"retrieve_result": True}, poll_interval=10, timeout=900, - deferrable=True, ) # [END 
how_to_cloud_vertex_ai_run_query_job_operator] + # [START how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + run_query_job_deferrable = RunQueryJobOperator( + task_id="run_query_job_deferrable", + project_id=PROJECT_ID, + location=LOCATION, + agent_engine_id=AGENT_ENGINE_ID, + config=QUERY_CONFIG, + check_config={"retrieve_result": True}, + poll_interval=10, + timeout=900, + deferrable=True, + ) + # [END how_to_cloud_vertex_ai_run_query_job_operator_deferrable] + # [START how_to_cloud_vertex_ai_update_agent_engine_operator] update_agent_engine = UpdateAgentEngineOperator( task_id="update_agent_engine", @@ -149,7 +199,14 @@ ) # [END how_to_cloud_vertex_ai_delete_agent_engine_operator] - (create_agent_engine >> get_agent_engine >> run_query_job >> update_agent_engine >> delete_agent_engine) + ( + create_agent_engine + >> get_agent_engine + >> run_query_job + >> run_query_job_deferrable + >> update_agent_engine + >> delete_agent_engine + ) # ### Everything below this line is not part of example ### # ### Just for system tests purpose ### diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index 520d37b73c656..a3be32b7222da 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -72,7 +72,6 @@ def test_create_agent_engine(self, mock_get_client): mock_get_client.assert_called_once_with(self.hook, project_id=GCP_PROJECT, location=GCP_LOCATION) mock_get_client.return_value.create.assert_called_once_with( - agent_engine=None, agent=None, config=CONFIG, ) @@ -235,25 +234,6 @@ def test_update_agent_engine(self, mock_get_client): mock_get_client.return_value.update.assert_called_once_with( name=AGENT_ENGINE_NAME, agent=None, - agent_engine=None, - config=CONFIG, - ) - assert result == 
mock_get_client.return_value.update.return_value - - @mock.patch(AGENT_ENGINE_STRING.format("AgentEngineHook.get_agent_engine_client"), autospec=True) - def test_update_agent_engine_with_agent_engine_alias(self, mock_get_client): - result = self.hook.update_agent_engine( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - agent_engine=mock.sentinel.agent_engine, - config=CONFIG, - ) - - mock_get_client.return_value.update.assert_called_once_with( - name=AGENT_ENGINE_NAME, - agent=None, - agent_engine=mock.sentinel.agent_engine, config=CONFIG, ) assert result == mock_get_client.return_value.update.return_value diff --git a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py index daf16ac46c7ac..4f67bba9fe5f6 100644 --- a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_agent_engine.py @@ -97,36 +97,10 @@ def test_execute(self, mock_hook, context): project_id=GCP_PROJECT, location=GCP_LOCATION, agent=None, - agent_engine=None, config=CONFIG, ) assert result == {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} - @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) - def test_execute_with_agent_engine_alias(self, mock_hook, context): - mock_hook.return_value.create_agent_engine.return_value = FakeAgentEngine( - {"name": AGENT_ENGINE_NAME, "display_name": "test-agent-engine"} - ) - op = CreateAgentEngineOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine=mock.sentinel.agent_engine, - config=CONFIG, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) - - op.execute(context=context) - - mock_hook.return_value.create_agent_engine.assert_called_once_with( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent=None, - 
agent_engine=mock.sentinel.agent_engine, - config=CONFIG, - ) - class TestGetAgentEngineOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) @@ -366,38 +340,10 @@ def test_execute(self, mock_hook, context): location=GCP_LOCATION, agent_engine_id=AGENT_ENGINE_ID, agent=None, - agent_engine=None, config=CONFIG, ) assert result == {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} - @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) - def test_execute_with_agent_engine_alias(self, mock_hook, context): - mock_hook.return_value.update_agent_engine.return_value = FakeAgentEngine( - {"name": AGENT_ENGINE_NAME, "display_name": "updated-agent-engine"} - ) - op = UpdateAgentEngineOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - agent_engine=mock.sentinel.agent_engine, - config=CONFIG, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) - - op.execute(context=context) - - mock_hook.return_value.update_agent_engine.assert_called_once_with( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - agent_engine_id=AGENT_ENGINE_ID, - agent=None, - agent_engine=mock.sentinel.agent_engine, - config=CONFIG, - ) - class TestDeleteAgentEngineOperator: @mock.patch(AGENT_ENGINE_PATH.format("AgentEngineHook"), autospec=True) From 178de31f9743f974286cf0b1c9fbcc9948f1913d Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Mon, 29 Jun 2026 17:44:56 -0300 Subject: [PATCH 29/32] Fix Vertex AI Agent Engine system test docs --- .../example_vertex_ai_agent_engine.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py index 9d856731ad9a1..c505052746557 100644 --- 
a/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py +++ b/providers/google/tests/system/google/cloud/vertex_ai/example_vertex_ai_agent_engine.py @@ -58,18 +58,11 @@ def _get_json_env(name: str, default: dict[str, str]) -> dict[str, str]: return json.loads(value) if value else default -LOCATION = _get_env("GCP_REGION", "us-central1") -PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or _get_env("GCP_PROJECT_ID", "default") -CONTAINER_URI = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_URI") or _get_env( - "GCP_AGENT_ENGINE_CONTAINER_URI", - "us-central1-docker.pkg.dev/example-project/example-repository/example-agent:latest", -) -CONTAINER_ENV_VARS: dict[str, str] = _get_json_env( - "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_ENV_VARS", - {}, -) -if not CONTAINER_ENV_VARS: - CONTAINER_ENV_VARS = { +def _get_container_env_vars() -> dict[str, str]: + return _get_json_env( + "SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_ENV_VARS", + {}, + ) or { "GCP_PROJECT": PROJECT_ID, "GCP_REGION": LOCATION, "GEMINI_MODEL_ID": _get_env("GEMINI_MODEL_ID", "gemini-2.5-pro"), @@ -81,6 +74,15 @@ def _get_json_env(name: str, default: dict[str, str]) -> dict[str, str]: "AGENT_USE_MODEL": _get_env("AGENT_USE_MODEL", "true"), "AGENT_USE_MOCKS": _get_env("AGENT_USE_MOCKS", "false"), } + + +LOCATION = _get_env("GCP_REGION", "us-central1") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or _get_env("GCP_PROJECT_ID", "default") +CONTAINER_URI = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_CONTAINER_URI") or _get_env( + "GCP_AGENT_ENGINE_CONTAINER_URI", + "us-central1-docker.pkg.dev/example-project/example-repository/example-agent:latest", +) +CONTAINER_ENV_VARS = _get_container_env_vars() QUERY_STR = os.environ.get("SYSTEM_TESTS_VERTEX_AI_AGENT_ENGINE_QUERY") or json.dumps( { "dag_id": "gcp_agentengine_demo_failing_etl", From 6e0a23ffa9740788bfc471deb8c1fc7db2151133 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante 
<62363051+AlejandroMorgante@users.noreply.github.com> Date: Wed, 1 Jul 2026 20:05:16 -0300 Subject: [PATCH 30/32] Avoid blocking Agent Engine operation polling --- .../cloud/hooks/vertex_ai/agent_engine.py | 23 ++++++++++-- .../resources/agent_engine/__init__.py | 16 ++++++++ .../hooks/vertex_ai/test_agent_engine.py | 37 +++++++++++++++---- 3 files changed, 64 insertions(+), 12 deletions(-) create mode 100644 providers/google/tests/system/google/cloud/vertex_ai/resources/agent_engine/__init__.py diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index 591b4ef30b1f2..a7ac846985772 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any import google.auth.transport.requests +from aiohttp import ClientSession from asgiref.sync import sync_to_async from vertexai import Client @@ -367,12 +368,26 @@ async def get_agent_engine_operation( ) -> dict[str, Any]: """Return a Vertex AI Agent Engine long-running operation.""" sync_hook = await self.get_sync_hook() - return await sync_to_async(sync_hook.get_agent_engine_operation)( - project_id=project_id, + project_id = project_id or sync_hook.project_id + operation_name = self.sync_hook_class.build_operation_name(project_id, location, operation_id) + url = VERTEX_AI_AGENT_ENGINE_OPERATION_URL.format( location=location, - operation_id=operation_id, - request_timeout=request_timeout, + api_version=VERTEX_AI_AGENT_ENGINE_API_VERSION, + operation_name=operation_name, ) + credentials = sync_hook.get_credentials() + + if not credentials.valid: + credentials.refresh(google.auth.transport.requests.Request()) + + async with ClientSession() as session: + async with session.get( + url, + headers={"Authorization": f"Bearer 
{credentials.token}"}, + timeout=request_timeout, + ) as response: + response.raise_for_status() + return await response.json() async def check_query_agent_engine_job( self, diff --git a/providers/google/tests/system/google/cloud/vertex_ai/resources/agent_engine/__init__.py b/providers/google/tests/system/google/cloud/vertex_ai/resources/agent_engine/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/google/tests/system/google/cloud/vertex_ai/resources/agent_engine/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py index a3be32b7222da..ed4f0c7220188 100644 --- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py +++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_agent_engine.py @@ -363,24 +363,45 @@ def setup_method(self): ): self.hook = AgentEngineAsyncHook(gcp_conn_id=TEST_GCP_CONN_ID) + @mock.patch(AGENT_ENGINE_STRING.format("ClientSession"), autospec=True) + @pytest.mark.parametrize( + ("valid_credentials", "expected_refresh_calls"), + [ + (False, 1), + (True, 0), + ], + ) @pytest.mark.asyncio - async def test_get_agent_engine_operation_calls_sync_hook(self): + async def test_get_agent_engine_operation( + self, mock_client_session, valid_credentials, expected_refresh_calls + ): sync_hook = mock.Mock(spec=AgentEngineHook) - sync_hook.get_agent_engine_operation.return_value = {"name": OPERATION_NAME, "done": True} + sync_hook.project_id = GCP_PROJECT + credentials = mock.Mock(valid=valid_credentials, token="test-token") + sync_hook.get_credentials.return_value = credentials self.hook.get_sync_hook = mock.AsyncMock(return_value=sync_hook) + session = mock.MagicMock() + mock_client_session.return_value.__aenter__.return_value = session + response = mock.Mock() + response.json = mock.AsyncMock(return_value={"name": OPERATION_NAME, "done": True}) + request = mock.MagicMock() + request.__aenter__.return_value = response + session.get.return_value = request result = await self.hook.get_agent_engine_operation( - project_id=GCP_PROJECT, location=GCP_LOCATION, operation_id=OPERATION_ID, + request_timeout=10, ) - sync_hook.get_agent_engine_operation.assert_called_once_with( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - operation_id=OPERATION_ID, - request_timeout=60.0, + assert credentials.refresh.call_count == expected_refresh_calls + 
session.get.assert_called_once_with( + f"https://{GCP_LOCATION}-aiplatform.googleapis.com/v1beta1/{OPERATION_NAME}", + headers={"Authorization": "Bearer test-token"}, + timeout=10, ) + response.raise_for_status.assert_called_once_with() + sync_hook.get_agent_engine_operation.assert_not_called() assert result == {"name": OPERATION_NAME, "done": True} @pytest.mark.asyncio From 9d744193af7c64a4d1a5ef804428e0916660abbf Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Thu, 2 Jul 2026 13:56:38 -0300 Subject: [PATCH 31/32] Add api reference links for Vertex AI Agent Engine --- providers/google/docs/operators/cloud/vertex_ai.rst | 3 +++ .../providers/google/cloud/hooks/vertex_ai/agent_engine.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index f1ba0e652d315..9e9a511e86cac 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -29,6 +29,9 @@ request predictions with Vertex AI. Managing Agent Engines ^^^^^^^^^^^^^^^^^^^^^^ +The operators below manage `Vertex AI Agent Engine +`__ resources. + To create a Vertex AI Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.CreateAgentEngineOperator`. 
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index a7ac846985772..a5b25794a26c5 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -64,7 +64,12 @@ def serialize_value(value: Any) -> Any: class AgentEngineHook(GoogleBaseHook): - """Hook for Google Cloud Vertex AI Agent Engine APIs.""" + """ + Hook for Google Cloud Vertex AI Agent Engine APIs. + + Wraps the ``agent_engines`` module of the Vertex AI SDK client: + https://cloud.google.com/python/docs/reference/agentplatform/latest/vertexai._genai.agent_engines.AgentEngines + """ def __init__( self, From 54dbed97dc36ba031927f788d1241a555c3d0df6 Mon Sep 17 00:00:00 2001 From: AlejandroMorgante <62363051+AlejandroMorgante@users.noreply.github.com> Date: Thu, 2 Jul 2026 14:08:18 -0300 Subject: [PATCH 32/32] Use canonical Google docs domain for Agent Engine reference links --- providers/google/docs/operators/cloud/vertex_ai.rst | 2 +- .../providers/google/cloud/hooks/vertex_ai/agent_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/google/docs/operators/cloud/vertex_ai.rst b/providers/google/docs/operators/cloud/vertex_ai.rst index 9e9a511e86cac..f11b8787ef684 100644 --- a/providers/google/docs/operators/cloud/vertex_ai.rst +++ b/providers/google/docs/operators/cloud/vertex_ai.rst @@ -30,7 +30,7 @@ Managing Agent Engines ^^^^^^^^^^^^^^^^^^^^^^ The operators below manage `Vertex AI Agent Engine -`__ resources. +`__ resources. To create a Vertex AI Agent Engine you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.agent_engine.CreateAgentEngineOperator`. 
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py index a5b25794a26c5..f90d53475699d 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/agent_engine.py @@ -68,7 +68,7 @@ class AgentEngineHook(GoogleBaseHook): Hook for Google Cloud Vertex AI Agent Engine APIs. Wraps the ``agent_engines`` module of the Vertex AI SDK client: - https://cloud.google.com/python/docs/reference/agentplatform/latest/vertexai._genai.agent_engines.AgentEngines + https://docs.cloud.google.com/python/docs/reference/agentplatform/latest/vertexai._genai.agent_engines.AgentEngines """ def __init__(