diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml index 42d820ab473..a36d373c484 100644 --- a/.github/workflows/check-file-contents.yml +++ b/.github/workflows/check-file-contents.yml @@ -108,13 +108,30 @@ jobs: if [ -n "$CHANGED_FILES" ]; then echo "Checking for hardcoded endpoints in: $CHANGED_FILES" - # 1. Identify files containing any googleapis.com URL. + # 1. Identify files containing any googleapis.com URL (candidate set). set +e FILES_WITH_ENDPOINTS=$(grep -lE 'https?://[a-zA-Z0-9.-]+\.googleapis\.com' $CHANGED_FILES) - # 2. From those, identify files that are MISSING the required mTLS version. - if [ -n "$FILES_WITH_ENDPOINTS" ]; then - FILES_MISSING_MTLS=$(grep -L '.mtls.googleapis.com' $FILES_WITH_ENDPOINTS) + # 2. Filter the candidate set: drop files whose only googleapis.com + # references are OAuth 2.0 scope URLs (e.g. + # https://www.googleapis.com/auth/cloud-platform). Those are + # identity strings, not API endpoints — they don't have mTLS + # counterparts and never will. Without this filter, any source + # file that legitimately declares an OAuth scope (very common + # for ADK plugins integrating Google APIs) trips the gate even + # when no real endpoint is hardcoded. + FILES_WITH_REAL_ENDPOINTS="" + for f in $FILES_WITH_ENDPOINTS; do + if grep -E 'https?://[a-zA-Z0-9.-]+\.googleapis\.com' "$f" \ + | grep -vqE 'googleapis\.com/auth/'; then + FILES_WITH_REAL_ENDPOINTS="$FILES_WITH_REAL_ENDPOINTS $f" + fi + done + + # 3. From the filtered set, identify files MISSING the required + # mTLS variant. + if [ -n "$FILES_WITH_REAL_ENDPOINTS" ]; then + FILES_MISSING_MTLS=$(grep -L '.mtls.googleapis.com' $FILES_WITH_REAL_ENDPOINTS) fi set -e diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 90bd50628c3..b38427cd5bb 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from ..agents.invocation_context import InvocationContext + from ..events.event import Event logger: logging.Logger = logging.getLogger("google_adk." + __name__) tracer = trace.get_tracer( @@ -84,12 +85,57 @@ _SCHEMA_VERSION = "1" _SCHEMA_VERSION_LABEL_KEY = "adk_schema_version" +# ADK 2.0 envelope version. Stamped onto every ADK-enriched row as +# ``attributes.adk.schema_version``. Independent of the BigQuery row +# schema version above — this names the producer's ADK 2.0 attribute +# contract so downstream consumers can gate on it. +_ADK_ENVELOPE_SCHEMA_VERSION = "1" + _HITL_EVENT_MAP = MappingProxyType({ "adk_request_credential": "HITL_CREDENTIAL_REQUEST", "adk_request_confirmation": "HITL_CONFIRMATION_REQUEST", "adk_request_input": "HITL_INPUT_REQUEST", }) +# Reverse of _HITL_EVENT_MAP for the long-running-tool pause_kind +# discriminator. The id→name lookup routes ``adk_request_credential`` +# → ``hitl_credential`` etc.; everything else is ``tool``. +_HITL_PAUSE_KIND_MAP = MappingProxyType({ + "adk_request_credential": "hitl_credential", + "adk_request_confirmation": "hitl_confirmation", + "adk_request_input": "hitl_input", +}) + + +def _derive_scope( + isolation_scope: Optional[str], +) -> Optional[dict[str, str]]: + """Derives ``attributes.adk.scope`` from an Event's isolation_scope. + + Order is fixed: (1) None → null; (2) node-shape (``name@run_id`` or + ``parent/name@run_id``) → ``node_run``; (3) any other non-empty + string → ``function_call`` (model-provided FC IDs like ``call_*`` and + ``toolu_*`` legitimately match here); (4) empty/non-string → ``unknown`` + with a warning. Steps 2 and 3 are intentionally ordered: a bare + ``name@run_id`` must classify as ``node_run`` first, not as + ``function_call`` by fall-through. + """ + if isolation_scope is None: + return None + if not isinstance(isolation_scope, str) or not isolation_scope: + logger.warning( + "Unexpected isolation_scope shape: %r; classifying as 'unknown'", + isolation_scope, + ) + return {"id": str(isolation_scope), "kind": "unknown"} + # Node-shape: last segment contains '@'. The full string may also be + # path-prefixed (e.g. ``wf/A@1/B@2``). + last_segment = isolation_scope.rsplit("/", 1)[-1] + if "@" in last_segment: + return {"id": isolation_scope, "kind": "node_run"} + return {"id": isolation_scope, "kind": "function_call"} + + # Track all living plugin instances so the fork handler can reset # them proactively in the child, before _ensure_started runs. _LIVE_PLUGINS: weakref.WeakSet = weakref.WeakSet() @@ -634,20 +680,35 @@ class BigQueryLoggerConfig: @dataclass class _SpanRecord: - """A single record on the unified span stack. - - Consolidates span, id, ownership, and timing into one object - so all stacks stay in sync by construction. - - Note: The plugin intentionally does NOT attach its spans to the - ambient OTel context (no ``context.attach``). This prevents the - plugin from corrupting the framework's span hierarchy when an - external OTel exporter (e.g. ``opentelemetry-instrumentation-vertexai``) - is active. See https://github.com/google/adk-python/issues/4561. + """A single record on the BQAA plugin's internal span stack. + + Stores the IDs and timing the plugin needs to populate BigQuery + ``span_id`` / ``parent_span_id`` / ``trace_id`` / ``latency_ms`` + columns. Crucially, no OpenTelemetry ``Span`` object is held. + + Background — prior approach and the bug it caused: + The previous implementation created real OTel spans via + ``tracer.start_span(...)`` purely as ID carriers. When the host + application has an OTel exporter configured (notably Agent Engine + with ``GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY=true``), those + plugin-owned spans were exported to Cloud Trace alongside the + framework's real spans — producing a duplicate-span view for + every BQAA-instrumented operation. See haiyuan-eng-google/BQAA-SDK#94. + + The plugin already tracked all parent / child relationships on + this internal stack, so the OTel span object was incidental to + correctness. We now store ``trace_id`` directly on each record + (inherited from the ambient OTel span when present, generated + otherwise) and skip span creation entirely. Cross-system + correlation with Cloud Trace still works via ``trace_id`` + inheritance. + + ``attach_current_span`` (which observes the ambient span without + owning one) is unaffected by this change. """ - span: trace.Span span_id: str + trace_id: str owns_span: bool start_time_ns: int first_token_time: Optional[float] = None @@ -689,17 +750,16 @@ def init_trace(callback_context: CallbackContext) -> None: @staticmethod def get_trace_id(callback_context: CallbackContext) -> Optional[str]: - """Gets the trace ID from the current span or invocation_id.""" + """Gets the trace ID from the current span stack or invocation_id.""" records = _span_records_ctx.get() if records: - current_span = records[-1].span - if current_span.get_span_context().is_valid: - return format(current_span.get_span_context().trace_id, "032x") + return records[-1].trace_id - # Fallback to OTel context - current_span = trace.get_current_span() - if current_span.get_span_context().is_valid: - return format(current_span.get_span_context().trace_id, "032x") + # Fallback to ambient OTel context (e.g. callbacks fired before + # any plugin span was pushed). + ambient_ctx = trace.get_current_span().get_span_context() + if ambient_ctx.is_valid: + return format(ambient_ctx.trace_id, "032x") return callback_context.invocation_id @@ -708,47 +768,48 @@ def push_span( callback_context: CallbackContext, span_name: Optional[str] = "adk-span", ) -> str: - """Starts a new span and pushes it onto the stack. - - The span is created but NOT attached to the ambient OTel context, - so it cannot corrupt the framework's own span hierarchy. The - plugin tracks span_id / parent_span_id internally via its own - contextvar stack. - - If OTel is not configured (returning non-recording spans), a UUID - fallback is generated to ensure span_id and parent_span_id are - populated in BigQuery logs. + """Pushes a BQAA-internal span record onto the stack. + + No OpenTelemetry span is created — see ``_SpanRecord`` for + background. The record carries everything the plugin needs to + populate BigQuery columns: + + * ``span_id`` — newly generated 16-hex string. + * ``trace_id`` — inherited by precedence: + 1. Top of the existing internal stack (keeps every push + within an invocation under one trace_id). + 2. Ambient OTel span when valid (e.g. the framework's Runner + span, or an Agent Engine root span) — keeps BigQuery rows + joinable to Cloud Trace via the shared ``trace_id``. + 3. A fresh 32-hex value (no ambient context, e.g. unit tests + or non-OTel runtimes). + * ``start_time_ns`` — for the eventual ``latency_ms`` on pop. + + ``span_name`` is preserved on the signature for API stability but + is no longer used (no OTel span name is set). """ + del span_name # No-op: kept for API stability; no OTel span is created. TraceManager.init_trace(callback_context) - # Create the span without attaching it to the ambient context. - # This avoids re-parenting framework spans like ``call_llm`` - # or ``execute_tool``. See #4561. - # - # If the internal stack already has a span, create the new span - # as a child so it shares the same trace_id. Without this, each - # ``start_span`` would be an independent root with its own - # trace_id — causing trace_id fracture (see #4645). records = TraceManager._get_records() - parent_ctx = None - if records and records[-1].span.get_span_context().is_valid: - parent_ctx = trace.set_span_in_context(records[-1].span) - span = tracer.start_span(span_name, context=parent_ctx) - - if span.get_span_context().is_valid: - span_id_str = format(span.get_span_context().span_id, "016x") + if records: + trace_id = records[-1].trace_id else: - span_id_str = uuid.uuid4().hex + ambient_ctx = trace.get_current_span().get_span_context() + if ambient_ctx.is_valid: + trace_id = format(ambient_ctx.trace_id, "032x") + else: + trace_id = uuid.uuid4().hex # 32 hex chars + + span_id_str = uuid.uuid4().hex[:16] record = _SpanRecord( - span=span, span_id=span_id_str, + trace_id=trace_id, owns_span=True, start_time_ns=time.time_ns(), ) - - new_records = list(records) + [record] - _span_records_ctx.set(new_records) + _span_records_ctx.set(list(records) + [record]) return span_id_str @@ -756,30 +817,31 @@ def push_span( def attach_current_span( callback_context: CallbackContext, ) -> str: - """Records the current OTel span on the stack without owning it. + """Records the ambient OTel span's IDs on the stack without owning it. - The span is NOT re-attached to the ambient context; it is only - tracked internally for span_id / parent_span_id resolution. + No OTel span is created or attached. This path captures the + ambient span's ``trace_id`` / ``span_id`` so plugin-emitted + BigQuery rows correlate with whatever Cloud Trace / external + exporter the host is already running. """ TraceManager.init_trace(callback_context) - span = trace.get_current_span() - - if span.get_span_context().is_valid: - span_id_str = format(span.get_span_context().span_id, "016x") + ambient_ctx = trace.get_current_span().get_span_context() + if ambient_ctx.is_valid: + span_id_str = format(ambient_ctx.span_id, "016x") + trace_id = format(ambient_ctx.trace_id, "032x") else: - span_id_str = uuid.uuid4().hex + span_id_str = uuid.uuid4().hex[:16] + trace_id = uuid.uuid4().hex record = _SpanRecord( - span=span, span_id=span_id_str, + trace_id=trace_id, owns_span=False, start_time_ns=time.time_ns(), ) - records = TraceManager._get_records() - new_records = list(records) + [record] - _span_records_ctx.set(new_records) + _span_records_ctx.set(list(records) + [record]) return span_id_str @@ -828,10 +890,10 @@ def ensure_invocation_span( @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: - """Ends the current span and pops it from the stack. + """Pops the top span record from the internal stack. - No ambient OTel context is detached because we never attached - one in the first place (see ``push_span``). + Returns ``(span_id, duration_ms)``. No OTel span is ended + because the plugin no longer creates one (see ``_SpanRecord``). """ records = _span_records_ctx.get() if not records: @@ -841,29 +903,13 @@ def pop_span() -> tuple[Optional[str], Optional[int]]: record = new_records.pop() _span_records_ctx.set(new_records) - # Calculate duration - duration_ms = None - otel_start = getattr(record.span, "start_time", None) - if isinstance(otel_start, (int, float)) and otel_start: - duration_ms = int((time.time_ns() - otel_start) / 1_000_000) - else: - duration_ms = int((time.time_ns() - record.start_time_ns) / 1_000_000) - - if record.owns_span: - record.span.end() - + duration_ms = int((time.time_ns() - record.start_time_ns) / 1_000_000) return record.span_id, duration_ms @staticmethod def clear_stack() -> None: """Clears all span records. Safety net for cross-invocation cleanup.""" - records = _span_records_ctx.get() - if records: - # End any owned spans to avoid OTel resource leaks. - for record in reversed(records): - if record.owns_span: - record.span.end() - _span_records_ctx.set([]) + _span_records_ctx.set([]) @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: @@ -894,19 +940,11 @@ def get_root_agent_name() -> Optional[str]: @staticmethod def get_start_time(span_id: str) -> Optional[float]: - """Gets start time of a span by ID.""" + """Gets start time of a span by ID (seconds since epoch).""" records = _span_records_ctx.get() if records: for record in reversed(records): if record.span_id == span_id: - # Try OTel span start_time first - otel_start = getattr(record.span, "start_time", None) - if ( - record.span.get_span_context().is_valid - and isinstance(otel_start, (int, float)) - and otel_start - ): - return otel_start / 1_000_000_000.0 return record.start_time_ns / 1_000_000_000.0 return None @@ -1171,7 +1209,21 @@ async def requests_iter(): yield req async def perform_write(): - responses = await self.write_client.append_rows(requests_iter()) + # The AppendRows streaming RPC does not auto-populate the + # request-routing header, so writes to any region other than + # the US multiregion fail with a "session not found" / + # stream-not-found error. Set the routing header explicitly + # (same as google.cloud.bigquery_storage_v1.writer) so the + # request reaches the region that owns the write stream. + responses = await self.write_client.append_rows( + requests_iter(), + metadata=( + ( + "x-goog-request-params", + f"write_stream={self.write_stream}", + ), + ), + ) async for response in responses: error = getattr(response, "error", None) error_code = getattr(error, "code", None) @@ -1861,6 +1913,11 @@ def _get_events_schema() -> list[bigquery.SchemaField]: "JSON_QUERY(content, '$.result') AS tool_result", "JSON_VALUE(content, '$.tool_origin') AS tool_origin", "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + # Long-running pair keys: null for ordinary completions, + # populated on the user-message resume path so typed views can + # do the TOOL_PAUSED ↔ TOOL_COMPLETED join end-to-end. + "JSON_VALUE(attributes, '$.adk.pause_kind') AS pause_kind", + "JSON_VALUE(attributes, '$.adk.function_call_id') AS function_call_id", ], "TOOL_ERROR": [ "JSON_VALUE(content, '$.tool') AS tool_name", @@ -1922,6 +1979,52 @@ def _get_events_schema() -> list[bigquery.SchemaField]: " '$.source_event_branch') AS source_event_branch" ), ], + "AGENT_TRANSFER": [ + "JSON_VALUE(content, '$.from_agent') AS from_agent", + "JSON_VALUE(content, '$.to_agent') AS to_agent", + "JSON_VALUE(attributes, '$.adk.source_event_id') AS source_event_id", + ], + "EVENT_COMPACTION": [ + ( + "CAST(JSON_VALUE(content," + " '$.start_timestamp') AS FLOAT64) AS start_seconds" + ), + ( + "CAST(JSON_VALUE(content," + " '$.end_timestamp') AS FLOAT64) AS end_seconds" + ), + ( + "TIMESTAMP_MICROS(CAST(CAST(JSON_VALUE(content," + " '$.start_timestamp') AS FLOAT64) * 1000000 AS INT64))" + " AS window_start" + ), + ( + "TIMESTAMP_MICROS(CAST(CAST(JSON_VALUE(content," + " '$.end_timestamp') AS FLOAT64) * 1000000 AS INT64))" + " AS window_end" + ), + "JSON_QUERY(content, '$.compacted_content') AS compacted_content", + ], + "AGENT_STATE_CHECKPOINT": [ + "JSON_QUERY(content, '$.agent_state') AS agent_state", + # Presence discriminator. JSON_QUERY on an explicit JSON null + # returns JSON null (not SQL NULL), so consumers must check + # JSON_TYPE: SQL NULL = key absent, 'null' = explicit JSON + # null (the {agent_state: null, end_of_agent: true} shape), + # anything else = a real state object. + "JSON_TYPE(JSON_QUERY(content, '$.agent_state')) AS agent_state_type", + ( + "SAFE_CAST(JSON_VALUE(content," + " '$.end_of_agent') AS BOOL) AS end_of_agent" + ), + "JSON_VALUE(attributes, '$.adk.source_event_id') AS source_event_id", + ], + "TOOL_PAUSED": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(attributes, '$.adk.pause_kind') AS pause_kind", + "JSON_VALUE(attributes, '$.adk.function_call_id') AS function_call_id", + ], } _VIEW_SQL_TEMPLATE = """\ @@ -1962,6 +2065,21 @@ class EventData: error_message: Optional[str] = None extra_attributes: dict[str, Any] = field(default_factory=dict) trace_id_override: Optional[str] = None + # ADK 2.0 envelope: callbacks that hold the source Event pass it here + # so ``_log_event`` can stamp ``attributes.adk.{source_event_id, node, + # branch, scope, ...}``. Leave None for rows that don't originate from + # an Event — the envelope helper omits those keys rather than + # synthesizing fake identity. Because the + # surrounding column is BigQuery JSON, an omitted key resolves to SQL + # NULL via ``JSON_VALUE(attributes, '$.adk.')``, so consumer + # gating with ``... IS NOT NULL`` works without explicit JSON nulls. + source_event: Optional["Event"] = None + # Producer-supplied extras that belong INSIDE ``attributes.adk`` (not + # at the top level of ``attributes``). C7's pair keys + # (``pause_kind`` / ``function_call_id``) ride here so consumer SQL + # like ``JSON_VALUE(attributes, '$.adk.function_call_id')`` lands at + # the right JSON path. + adk_extras: dict[str, Any] = field(default_factory=dict) class BigQueryAgentAnalyticsPlugin(BasePlugin): @@ -2748,6 +2866,113 @@ def _extract_latency( latency_json["time_to_first_token_ms"] = event_data.time_to_first_token_ms return latency_json or None + def _build_adk_envelope( + self, + callback_context: CallbackContext, + source_event: Optional["Event"], + ) -> dict[str, Any]: + """Builds the ``attributes.adk`` envelope. + + A1 / A2 (``schema_version``, ``app_name``) stamp on every ADK-enriched + row regardless of origin. A3 / C1 / C2 / C3 (``source_event_id``, + ``node``, ``branch``, ``scope``) and C8 (``route``, + ``render_ui_widgets``, ``rewind_before_invocation_id``) only stamp + when a source Event is provided — callback-only rows **omit** those + keys from the envelope rather than synthesizing fake identity. Since + the surrounding column is BigQuery JSON, an omitted key resolves to + SQL NULL via ``JSON_VALUE(attributes, '$.adk.')``; consumers + using ``JSON_VALUE(...) IS NOT NULL`` to gate on Event-originating + rows therefore work correctly without the producer writing explicit + JSON nulls. + """ + adk: dict[str, Any] = { + "schema_version": _ADK_ENVELOPE_SCHEMA_VERSION, + } + try: + adk["app_name"] = callback_context._invocation_context.session.app_name + except Exception: + adk["app_name"] = None + + if source_event is None: + return adk + + # Every getattr below is defensive: source_event is "anything the + # caller hands us", which in test suites can be a Mock. Best-effort + # enrichment means "leave null on missing attrs", never crash the + # row. + try: + source_event_id = getattr(source_event, "id", None) + if source_event_id: + adk["source_event_id"] = source_event_id # A3 + except Exception: + pass + + # C1: node = {path, run_id, parent_path}. NodeInfo.path defaults to + # the empty string in current ADK (events/event.py:45); preserve it + # verbatim and emit parent_path = null rather than synthesizing a + # fake workflow node from an empty path. + try: + node_info = getattr(source_event, "node_info", None) + if node_info is not None and hasattr(node_info, "path"): + path = getattr(node_info, "path", "") or "" + run_id = getattr(node_info, "run_id", None) + parent_path = None + if path and "/" in path: + parent_path = path.rsplit("/", 1)[0] + adk["node"] = { + "path": path, + "run_id": run_id, + "parent_path": parent_path, + } + except Exception: + pass + + # C2: branch — absent stays JSON null (no sentinel string). + try: + if hasattr(source_event, "branch"): + adk["branch"] = source_event.branch + except Exception: + pass + + # Scope shape derivation. Order matters: + # node-shape patterns must be checked before falling through to + # function_call so bare ``name@run_id`` doesn't misclassify. + try: + if hasattr(source_event, "isolation_scope"): + adk["scope"] = _derive_scope(source_event.isolation_scope) + except Exception: + pass + + # Raw EventActions mirror (flat under attributes.adk). + # Stamp only when actually set so JSON doesn't bloat with nulls. + try: + actions = getattr(source_event, "actions", None) + except Exception: + actions = None + if actions is not None: + try: + route = getattr(actions, "route", None) + if route is not None: + adk["route"] = route + except Exception: + pass + try: + widgets = getattr(actions, "render_ui_widgets", None) + if widgets is not None: + adk["render_ui_widgets"] = [ + w.model_dump() if hasattr(w, "model_dump") else w for w in widgets + ] + except Exception: + pass + try: + rewind = getattr(actions, "rewind_before_invocation_id", None) + if rewind is not None: + adk["rewind_before_invocation_id"] = rewind + except Exception: + pass + + return adk + def _enrich_attributes( self, event_data: EventData, @@ -2757,12 +2982,23 @@ def _enrich_attributes( Reads ``model``, ``model_version``, and ``usage_metadata`` from *event_data*, copies ``extra_attributes``, then adds session metadata - and custom tags. + and custom tags. Also stamps the ``adk`` envelope. Returns: A new dict ready for JSON serialization into the attributes column. """ attrs: dict[str, Any] = dict(event_data.extra_attributes) + adk_envelope = self._build_adk_envelope( + callback_context, event_data.source_event + ) + # Merge producer-supplied adk_extras (long-running pair keys etc.) + # INTO the adk envelope so consumer SQL on + # ``$.adk.pause_kind`` / ``$.adk.function_call_id`` resolves. + # adk_envelope wins on key conflict — producer-derived envelope + # is the source of truth for identity fields like source_event_id. + for k, v in event_data.adk_extras.items(): + adk_envelope.setdefault(k, v) + attrs["adk"] = adk_envelope attrs["root_agent_name"] = TraceManager.get_root_agent_name() if event_data.model: @@ -2880,10 +3116,31 @@ async def _log_event( except (TypeError, ValueError): attributes_json = json.dumps(attributes, default=str) + # InvocationContext.agent is Optional and is None for invocations + # driven by a workflow engine (deterministic nodes). Derive the + # row's agent label from the underlying invocation context rather + # than ReadonlyContext.agent_name, so the fallback behavior does + # not depend on whether core raises AttributeError, returns a + # string sentinel like "unknown", or returns None for the no-agent + # case. Chain: + # agent present → agent.name + # no agent + source Event → Event.author (the emitting node) + # callback-only row → null + agent_obj = getattr(callback_context._invocation_context, "agent", None) + agent_name = getattr(agent_obj, "name", None) if agent_obj else None + if agent_name is None: + agent_name = getattr(event_data.source_event, "author", None) + logger.debug( + "InvocationContext.agent is unavailable; using source Event" + " author %r as the row's agent for event_type=%s.", + agent_name, + event_type, + ) + row = { "timestamp": timestamp, "event_type": event_type, - "agent": callback_context.agent_name, + "agent": agent_name, "user_id": callback_context.user_id, "session_id": callback_context.session.id, "invocation_id": callback_context.invocation_id, @@ -2915,9 +3172,14 @@ async def on_user_message_callback( ) -> None: """Parity with V1: Logs USER_MESSAGE_RECEIVED event. - Also detects HITL completion responses (user-sent - ``FunctionResponse`` parts with ``adk_request_*`` names) and emits - dedicated ``HITL_*_COMPLETED`` events. + Also detects: + * HITL completion responses (user-sent ``FunctionResponse`` parts + with ``adk_request_*`` names) → ``HITL_*_COMPLETED``. + * Non-HITL ``FunctionResponse`` parts from a user message → these + are the long-running tool completions for tools that paused via + ``TOOL_PAUSED``. Emitted as ``TOOL_COMPLETED`` with + ``pause_kind = 'tool'`` and ``function_call_id`` so the customer + can join the pair from BigQuery. Args: invocation_context: The context of the current invocation. @@ -2931,26 +3193,56 @@ async def on_user_message_callback( raw_content=user_message, ) - # Detect HITL completion responses in the user message. + # Detect completion responses in the user message. if user_message and user_message.parts: for part in user_message.parts: - if part.function_response: - hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) - if hitl_event: - resp_truncated, is_truncated = _recursive_smart_truncate( - part.function_response.response or {}, - self.config.max_content_length, - ) - content_dict = { - "tool": part.function_response.name, - "result": resp_truncated, - } - await self._log_event( - hitl_event + "_COMPLETED", - callback_ctx, - raw_content=content_dict, - is_truncated=is_truncated, + if not part.function_response: + continue + hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) + resp_truncated, is_truncated = _recursive_smart_truncate( + part.function_response.response or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_response.name, + "result": resp_truncated, + } + if hitl_event: + # HITL completions stay on the HITL_*_COMPLETED stream — they + # MUST NOT also emit TOOL_COMPLETED. + await self._log_event( + hitl_event + "_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + else: + # Non-HITL function_response arriving via a user message is + # by construction a long-running tool completion: regular + # tool calls complete inside the agent run via + # after_tool_callback, so a function_response inside a user + # message is the resume side of a previously-paused tool. + # Stamp the pair keys; pause_orphan / registry semantics + # are intentionally deferred. + if not part.function_response.id: + logger.debug( + "User-message function_response for tool %s has no id;" + " the resulting TOOL_COMPLETED row cannot pair with a" + " TOOL_PAUSED row.", + part.function_response.name, ) + await self._log_event( + "TOOL_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + event_data=EventData( + adk_extras={ + "pause_kind": "tool", + "function_call_id": part.function_response.id, + }, + ), + ) @_safe_callback async def on_event_callback( @@ -2993,11 +3285,85 @@ async def on_event_callback( "STATE_DELTA", callback_ctx, event_data=EventData( - extra_attributes={"state_delta": dict(event.actions.state_delta)} + source_event=event, + extra_attributes={"state_delta": dict(event.actions.state_delta)}, ), ) - # --- HITL event logging --- + # --- AGENT_TRANSFER --- + # actions.transfer_to_agent stores the *target* agent only + # (events/event_actions.py:75); from_agent is pinned to event.author + # by contract. Never fabricate authors on non-Event paths. + if event.actions.transfer_to_agent: + await self._log_event( + "AGENT_TRANSFER", + callback_ctx, + raw_content={ + "from_agent": event.author, + "to_agent": event.actions.transfer_to_agent, + }, + event_data=EventData(source_event=event), + ) + + # --- EVENT_COMPACTION --- + # EventCompaction.start_timestamp / end_timestamp are float epoch + # seconds. Preserve fractional precision here; consumer view + # conversion is deferred. + compaction = event.actions.compaction + if compaction is not None: + compacted_content, compaction_truncated = self._format_content_safely( + compaction.compacted_content + ) + await self._log_event( + "EVENT_COMPACTION", + callback_ctx, + raw_content={ + "start_timestamp": compaction.start_timestamp, + "end_timestamp": compaction.end_timestamp, + "compacted_content": compacted_content, + }, + is_truncated=compaction_truncated, + event_data=EventData(source_event=event), + ) + + # --- AGENT_STATE_CHECKPOINT --- + # Fires when *either* agent_state is set or end_of_agent is True; + # supports {agent_state: None, end_of_agent: True} payloads. + # Inline payload only — oversized-state GCS offload deferred. + if ( + event.actions.agent_state is not None + or event.actions.end_of_agent is True + ): + agent_state_dict, agent_state_truncated = ( + _recursive_smart_truncate( + event.actions.agent_state, + self.config.max_content_length, + ) + if event.actions.agent_state is not None + else (None, False) + ) + await self._log_event( + "AGENT_STATE_CHECKPOINT", + callback_ctx, + raw_content={ + "agent_state": agent_state_dict, + "end_of_agent": bool(event.actions.end_of_agent), + }, + is_truncated=agent_state_truncated, + event_data=EventData(source_event=event), + ) + + # --- HITL + TOOL_PAUSED (pair-key emit) + per-part + # iteration over event.content.parts --- + # TOOL_PAUSED fires per long_running_tool_id; pause_kind is derived + # via the id→name lookup against _HITL_PAUSE_KIND_MAP, so a HITL + # long-running call carries pause_kind = 'hitl_*' and a regular + # long-running tool carries pause_kind = 'tool'. function_call_id + # joins to the downstream TOOL_COMPLETED via the user message path. + # Use getattr so the existing Mock-based HITL test fixtures still + # work — they construct events without setting long_running_tool_ids. + long_running_ids = set(getattr(event, "long_running_tool_ids", None) or ()) + paused_ids_emitted: set[str] = set() if event.content and event.content.parts: for part in event.content.parts: # Detect HITL function calls (request events). @@ -3017,8 +3383,39 @@ async def on_event_callback( callback_ctx, raw_content=content_dict, is_truncated=is_truncated, + event_data=EventData(source_event=event), + ) + # Per-id TOOL_PAUSED emit. pause_kind derives from the + # function_call NAME — looking it up against the id value + # would misclassify every HITL pause as 'tool'. + if part.function_call.id in long_running_ids: + paused_ids_emitted.add(part.function_call.id) + pause_kind = _HITL_PAUSE_KIND_MAP.get( + part.function_call.name, "tool" + ) + args_truncated, is_truncated = _recursive_smart_truncate( + part.function_call.args or {}, + self.config.max_content_length, ) - # Detect HITL function responses (completion events). + await self._log_event( + "TOOL_PAUSED", + callback_ctx, + raw_content={ + "tool": part.function_call.name, + "args": args_truncated, + }, + is_truncated=is_truncated, + event_data=EventData( + source_event=event, + adk_extras={ + "pause_kind": pause_kind, + "function_call_id": part.function_call.id, + }, + ), + ) + # Detect HITL function responses (completion events). HITL + # function responses route ONLY here, never to TOOL_COMPLETED + # (verified by this file's HITL test suite). if part.function_response: hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) if hitl_event: @@ -3035,8 +3432,33 @@ async def on_event_callback( callback_ctx, raw_content=content_dict, is_truncated=is_truncated, + event_data=EventData(source_event=event), ) + # Fallback: a long_running_tool_id with no matching function_call + # part (possible after after_model_callback content rewrites) still + # gets a pairable TOOL_PAUSED row. Without the name we cannot derive + # an HITL pause_kind, so default to 'tool' and warn. + for orphan_pause_id in long_running_ids - paused_ids_emitted: + logger.warning( + "long_running_tool_id %s has no matching function_call part in" + " event %s; emitting TOOL_PAUSED with pause_kind='tool'.", + orphan_pause_id, + getattr(event, "id", None), + ) + await self._log_event( + "TOOL_PAUSED", + callback_ctx, + raw_content={"tool": None, "args": None}, + event_data=EventData( + source_event=event, + adk_extras={ + "pause_kind": "tool", + "function_call_id": orphan_pause_id, + }, + ), + ) + # --- A2A interaction logging --- # RemoteA2aAgent attaches cross-reference metadata to events: # a2a:task_id, a2a:context_id — correlation keys @@ -3070,6 +3492,7 @@ async def on_event_callback( raw_content=content_dict, is_truncated=is_truncated or content_truncated, event_data=EventData( + source_event=event, extra_attributes={ "a2a_metadata": a2a_truncated, }, @@ -3106,12 +3529,17 @@ async def on_event_callback( role=event.content.role, parts=visible_parts ) formatted, truncated = self._format_content_safely(visible_content) + # source_event=event carries the ADK envelope (A3 / node / + # branch / scope). The flat ``source_event_*`` extras are + # retained for backward compat with existing AGENT_RESPONSE + # consumers; the canonical keys are under ``attributes.adk.*``. await self._log_event( "AGENT_RESPONSE", callback_ctx, raw_content={"response": formatted}, is_truncated=truncated, event_data=EventData( + source_event=event, extra_attributes={ "source_event_id": event.id, "source_event_author": event.author, @@ -3122,24 +3550,6 @@ async def on_event_callback( return None - async def on_state_change_callback( - self, - *, - callback_context: CallbackContext, - state_delta: dict[str, Any], - ) -> None: - """Deprecated: use on_event_callback instead. - - This method is retained for API compatibility but is never invoked - by the framework (not in BasePlugin, PluginManager, or Runner). - State deltas are now captured via on_event_callback. - """ - logger.warning( - "on_state_change_callback is deprecated and never called by" - " the framework. State deltas are captured via" - " on_event_callback." - ) - @_safe_callback async def before_run_callback( self, *, invocation_context: "InvocationContext" diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index c6d47539d14..b20aaabad90 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -635,6 +635,55 @@ async def test_event_denylist( await asyncio.sleep(0.01) mock_write_client.append_rows.assert_called_once() + @pytest.mark.asyncio + async def test_append_rows_sets_regional_routing_header( + self, + mock_write_client, + callback_context, + mock_auth_default, + mock_bq_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """Regression test for cross-region writes (issue #262). + + The Storage Write API streaming AppendRows RPC does not + auto-populate the request-routing header, so writes to a dataset + outside the US multiregion (e.g. northamerica-northeast1) fail with + a "session not found" / stream-not-found error unless the header is + set explicitly. Assert the header is passed to append_rows so the + request reaches the region that owns the write stream. + """ + _ = mock_auth_default + _ = mock_bq_client + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + async with managed_plugin( + PROJECT_ID, + DATASET_ID, + table_id=TABLE_ID, + config=config, + location="northamerica-northeast1", + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + contents=[types.Content(parts=[types.Part(text="Prompt")])], + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await asyncio.sleep(0.01) # Allow background task to run + mock_write_client.append_rows.assert_called_once() + metadata = mock_write_client.append_rows.call_args.kwargs.get("metadata") + assert metadata is not None, "append_rows must receive routing metadata" + assert ( + "x-goog-request-params", + f"write_stream={DEFAULT_STREAM_NAME}", + ) in tuple(metadata) + @pytest.mark.asyncio async def test_content_formatter( self, @@ -2225,53 +2274,55 @@ class LocalIncident: assert content_json["result"]["kpi_missed"][0]["kpi"] == "latency" @pytest.mark.asyncio - async def test_otel_integration( + async def test_push_pop_does_not_call_tracer_start_span( self, callback_context, ): - """Verifies OpenTelemetry integration in TraceManager.""" - # Mock the tracer and span + """Regression guard for the duplicate-Cloud-Trace bug (issue #94). + + The plugin must NOT call ``tracer.start_span(...)`` from + ``push_span`` / ``pop_span``. Any owned OTel span goes through + the globally configured exporter (e.g. Cloud Trace via Agent + Engine telemetry) and surfaces as a duplicate span next to the + framework's real one. The plugin's internal stack is sufficient + for ``span_id`` / ``parent_span_id`` / ``trace_id`` resolution + without creating an exportable span. + """ mock_tracer = mock.Mock() - mock_span = mock.Mock() - mock_context = mock.Mock() - # Setup mock IDs (128-bit trace_id, 64-bit span_id) - trace_id_int = 0x12345678123456781234567812345678 - span_id_int = 0x1234567812345678 - mock_context.trace_id = trace_id_int - mock_context.span_id = span_id_int - mock_context.is_valid = True - mock_span.get_span_context.return_value = mock_context - mock_span.start_time = 1234567890000000000 # Mock start time in ns - mock_tracer.start_span.return_value = mock_span - # Patch the global tracer in the plugin module with mock.patch( - "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", mock_tracer + "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", + mock_tracer, ): - # Test push_span span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( callback_context, "test_span" ) - mock_tracer.start_span.assert_called_with("test_span", context=None) - assert span_id == format(span_id_int, "016x") - # Test get_trace_id - # We need to mock trace.get_current_span() to return our mock span - # because push_span calls trace.attach(), which affects the global context - with mock.patch( - "opentelemetry.trace.get_current_span", return_value=mock_span - ): - trace_id = bigquery_agent_analytics_plugin.TraceManager.get_trace_id( - callback_context - ) - assert trace_id == format(trace_id_int, "032x") - # Test pop_span - # pop_span calls span.end() - bigquery_agent_analytics_plugin.TraceManager.pop_span() - mock_span.end.assert_called_once() + assert isinstance(span_id, str) and len(span_id) == 16 + + trace_id = bigquery_agent_analytics_plugin.TraceManager.get_trace_id( + callback_context + ) + assert isinstance(trace_id, str) and len(trace_id) == 32 + + popped_span_id, _duration_ms = ( + bigquery_agent_analytics_plugin.TraceManager.pop_span() + ) + assert popped_span_id == span_id + + mock_tracer.start_span.assert_not_called() @pytest.mark.asyncio - async def test_otel_integration_real_provider(self, callback_context): - """Verifies TraceManager with a real OpenTelemetry TracerProvider.""" - # Setup OTEL with in-memory exporter + async def test_push_pop_does_not_export_spans_through_real_provider( + self, callback_context + ): + """End-to-end regression guard against #94 with a real OTel + provider + in-memory exporter. + + Wires an ``InMemorySpanExporter`` to a real ``TracerProvider``, + drives a push/pop cycle through ``TraceManager``, and asserts + that **zero** spans were exported. Pre-fix behavior was to + export one span per push/pop pair — visible to Cloud Trace as + duplicate spans alongside the framework's real ones. + """ # pylint: disable=g-import-not-at-top from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk.trace import export as trace_export @@ -2280,36 +2331,185 @@ async def test_otel_integration_real_provider(self, callback_context): # pylint: enable=g-import-not-at-top provider = trace_sdk.TracerProvider() exporter = in_memory_span_exporter.InMemorySpanExporter() - processor = trace_export.SimpleSpanProcessor(exporter) - provider.add_span_processor(processor) - tracer = provider.get_tracer("test_tracer") - # Patch the global tracer in the plugin module + provider.add_span_processor(trace_export.SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test_tracer") + with mock.patch( - "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", tracer + "google.adk.plugins.bigquery_agent_analytics_plugin.tracer", + real_tracer, ): - # 1. Start a span span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( callback_context, "test_span" ) - # Verify a span was started but not ended - current_spans = exporter.get_finished_spans() - assert not current_spans - # Verify we can retrieve the trace ID + assert exporter.get_finished_spans() == () + trace_id = bigquery_agent_analytics_plugin.TraceManager.get_trace_id( callback_context ) - assert trace_id is not None - # 2. End the span + assert trace_id is not None and len(trace_id) == 32 + popped_span_id, _ = ( bigquery_agent_analytics_plugin.TraceManager.pop_span() ) assert popped_span_id == span_id - # Verify span is now finished and exported - finished_spans = exporter.get_finished_spans() - assert len(finished_spans) == 1 - assert finished_spans[0].name == "test_span" - assert format(finished_spans[0].context.span_id, "016x") == span_id - assert format(finished_spans[0].context.trace_id, "032x") == trace_id + + assert exporter.get_finished_spans() == (), ( + "Plugin must not export OTel spans; any owned span would" + " surface as a duplicate in Cloud Trace alongside the" + " framework's real spans (issue #94)." + ) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_push_span_inherits_ambient_trace_id(self, callback_context): + """When the host has an ambient OTel span (e.g. Agent Engine's + Runner span), the plugin's ``trace_id`` MUST inherit from it so + BigQuery rows correlate with the host's Cloud Trace entries via + a shared ``trace_id``. + """ + # pylint: disable=g-import-not-at-top + from opentelemetry import trace as otel_trace + from opentelemetry.sdk import trace as trace_sdk + + # pylint: enable=g-import-not-at-top + provider = trace_sdk.TracerProvider() + host_tracer = provider.get_tracer("host_tracer") + + # Clear any state on the plugin's contextvar stack. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + with host_tracer.start_as_current_span("ambient-host-span") as host_span: + expected_trace_id = format(host_span.get_span_context().trace_id, "032x") + + # Plugin pushes its first internal span inside the ambient span. + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "bqaa-span" + ) + + plugin_trace_id = ( + bigquery_agent_analytics_plugin.TraceManager.get_trace_id( + callback_context + ) + ) + assert plugin_trace_id == expected_trace_id, ( + "Plugin must inherit ambient trace_id so BigQuery rows join" + " to Cloud Trace via the same trace_id" + ) + + # Nested plugin push also stays under the ambient trace_id. + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "bqaa-nested" + ) + assert ( + bigquery_agent_analytics_plugin.TraceManager.get_trace_id( + callback_context + ) + == expected_trace_id + ) + + bigquery_agent_analytics_plugin.TraceManager.clear_stack() + provider.shutdown() + del otel_trace # unused; imported for symmetry with provider setup + + @pytest.mark.asyncio + async def test_llm_request_response_share_span_id_contract( + self, callback_context + ): + """Lifecycle contract: ``LLM_REQUEST`` and ``LLM_RESPONSE`` for the + same model call share one ``span_id`` and one ``trace_id``. + + Models the structural pattern the real callbacks use: + * ``before_model_callback`` calls ``push_span(...)`` and writes + ``LLM_REQUEST`` with the returned ``span_id``. + * ``after_model_callback`` calls ``get_current_span_id()`` / + ``pop_span()`` and writes ``LLM_RESPONSE`` with the same + ``span_id``. + + A future change must not split this pair onto two different + ``span_id``s — that would break the documented BigQuery query + shape and the BQAA join contract. + """ + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM = bigquery_agent_analytics_plugin.TraceManager + + # before_model_callback path. + pushed_span_id = TM.push_span(callback_context, "llm_request") + request_trace_id = TM.get_trace_id(callback_context) + + # after_model_callback (final chunk) path. + response_top_of_stack = TM.get_current_span_id() + popped_span_id, _duration_ms = TM.pop_span() + response_trace_id = TM.get_trace_id(callback_context) + + assert response_top_of_stack == pushed_span_id + assert popped_span_id == pushed_span_id + # trace_id resolved on the response side may have to fall back + # past the now-empty stack — but if it does resolve, it must + # match what the request observed. An empty-stack fallback to + # invocation_id is acceptable here; what we are guarding against + # is the *pair* drifting onto two structurally different ids. + if response_trace_id is not None and len(response_trace_id) == 32: + assert response_trace_id == request_trace_id + + @pytest.mark.asyncio + async def test_tool_starting_completed_share_span_id_contract( + self, callback_context + ): + """Lifecycle contract: ``TOOL_STARTING`` and ``TOOL_COMPLETED`` for + the same tool call share one ``span_id``. + + Same shape as the LLM pair above — push on before, pop on after, + same id on both sides. + """ + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM = bigquery_agent_analytics_plugin.TraceManager + + # before_tool_callback path. + pushed_span_id = TM.push_span(callback_context, "tool") + starting_trace_id = TM.get_trace_id(callback_context) + + # after_tool_callback path. + popped_span_id, _duration_ms = TM.pop_span() + + assert popped_span_id == pushed_span_id + assert isinstance(starting_trace_id, str) and len(starting_trace_id) == 32 + + @pytest.mark.asyncio + async def test_streaming_llm_response_shares_span_id_until_final_contract( + self, callback_context + ): + """Streaming-response contract. + + On a streaming LLM call, ``after_model_callback`` is fired once + per partial chunk *plus* once for the final chunk. Partial fires + do NOT pop the span (see ``after_model_callback:3354-3363``) — + they only read ``get_current_span_id()`` and record first-token + timing. Only the final fire calls ``pop_span()``. + + All resulting ``LLM_RESPONSE`` rows therefore share one + ``span_id`` (the same as the paired ``LLM_REQUEST``). A future + change must not "dedupe" the partial rows by switching to a fresh + span id per chunk — those rows are real and intentional. + """ + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM = bigquery_agent_analytics_plugin.TraceManager + + pushed_span_id = TM.push_span(callback_context, "llm_request") + + # Simulate three partial chunks: each callback observes the same + # span_id at top of stack and does NOT pop. + for _ in range(3): + assert TM.get_current_span_id() == pushed_span_id + + # Final chunk: pop_span returns the same id and a populated + # latency. + popped_span_id, duration_ms = TM.pop_span() + assert popped_span_id == pushed_span_id + assert duration_ms is not None and duration_ms >= 0 + + # Stack must be empty after the final chunk. + assert TM.get_current_span_id() is None @pytest.mark.asyncio async def test_keyword_identifiers_emission_default( @@ -5853,8 +6053,10 @@ async def test_trace_id_continuity_no_ambient_span(self, callback_context): TM = bigquery_agent_analytics_plugin.TraceManager - # Create a real TracerProvider and patch the plugin's module-level - # tracer so push_span creates valid spans with proper trace_ids. + # Wire a real TracerProvider with an in-memory exporter so we can + # also assert the plugin path does NOT export anything through it. + # (push_span no longer creates OTel spans — see _SpanRecord; the + # exporter is here as a regression guard, not a span source.) exporter = InMemorySpanExporter() provider = SdkProvider() provider.add_span_processor(SimpleSpanProcessor(exporter)) @@ -6176,8 +6378,11 @@ async def test_starting_completed_same_span_with_ambient( assert len(agent_starting) == 1 assert len(agent_completed) == 1 - # Both events must share the same span_id (the ambient - # invoke_agent span) — no plugin-synthetic override. + # Both events must share the same span_id (the plugin-internal + # agent span pushed by before_agent_callback and popped by + # after_agent_callback). The lifecycle-pair invariant holds + # regardless of whether the id comes from a plugin-minted hex + # string or an ambient OTel span. assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] assert ( agent_starting[0]["parent_span_id"] @@ -6362,8 +6567,16 @@ def test_ensure_invocation_span_clears_stale_records(self, callback_context): provider.shutdown() - def test_clear_stack_ends_owned_spans(self, callback_context): - """clear_stack() ends all owned spans.""" + def test_clear_stack_does_not_export_spans(self, callback_context): + """``clear_stack()`` clears the internal records but does NOT + export any OTel spans (issue #94 regression guard). + + Pre-fix, ``clear_stack()`` called ``record.span.end()`` for every + owned record, which delivered the now-finished span to whatever + exporter the host had wired — duplicating it next to the + framework's real span in Cloud Trace. Post-fix the plugin owns + no OTel span at all; ``clear_stack()`` only resets the contextvar. + """ from opentelemetry.sdk.trace import TracerProvider as SdkProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter @@ -6384,6 +6597,8 @@ def test_clear_stack_ends_owned_spans(self, callback_context): records = list(bigquery_agent_analytics_plugin._span_records_ctx.get()) assert all(r.owns_span for r in records) + # No exported spans yet (the plugin never creates any). + assert exporter.get_finished_spans() == () TM.clear_stack() @@ -6391,9 +6606,12 @@ def test_clear_stack_ends_owned_spans(self, callback_context): result = bigquery_agent_analytics_plugin._span_records_ctx.get() assert result == [] - # Both owned spans should have been ended (exported). - exported = exporter.get_finished_spans() - assert len(exported) == 2 + # Still no exported spans — the regression guard for #94. + assert exporter.get_finished_spans() == (), ( + "clear_stack() must not export OTel spans; any owned span" + " would surface as a duplicate in Cloud Trace alongside the" + " framework's real spans (issue #94)." + ) provider.shutdown() @@ -7656,8 +7874,13 @@ async def test_skips_long_running_tool_events( bq_plugin_inst, mock_write_client, invocation_context, + dummy_arrow_schema, ): - """Long-running tool events are not logged as AGENT_RESPONSE.""" + """Long-running tool events are not logged as AGENT_RESPONSE. + + They DO emit TOOL_PAUSED — here via the unmatched-id fallback, since + the function_call part has no id matching the long_running_tool_id. + """ fc = types.FunctionCall(name="long_tool", args={}) event = event_lib.Event( author="agent", @@ -7665,11 +7888,16 @@ async def test_skips_long_running_tool_events( long_running_tool_ids={"call-1"}, ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) await bq_plugin_inst.on_event_callback( invocation_context=invocation_context, event=event ) await asyncio.sleep(0.05) - assert mock_write_client.append_rows.call_count == 0 + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + types_emitted = [r["event_type"] for r in rows] + assert "AGENT_RESPONSE" not in types_emitted + # The pause is still observable via the fallback TOOL_PAUSED row. + assert types_emitted == ["TOOL_PAUSED"] @pytest.mark.asyncio async def test_skips_thought_only_events( @@ -7787,3 +8015,730 @@ async def test_skips_executable_code_only_events( ) await asyncio.sleep(0.05) assert mock_write_client.append_rows.call_count == 0 + + +# ----------------------------------------------------------------------------- +# ADK 2.0 minimum producer cut +# +# Coverage matrix: +# A1 / A2 attributes.adk.{schema_version, app_name} on every row +# A3 attributes.adk.source_event_id on Event-originating rows +# C1 attributes.adk.node {path, run_id, parent_path} +# C2 attributes.adk.branch +# C3 attributes.adk.scope {id, kind} +# C4 AGENT_TRANSFER emit +# C5 EVENT_COMPACTION emit (preserves fractional float epoch) +# C6 AGENT_STATE_CHECKPOINT emit (both shapes) + id-stabilization +# C7 TOOL_PAUSED with pause_kind / function_call_id +# HITL non-routing to TOOL_COMPLETED +# user-message TOOL_COMPLETED with pause_kind='tool' +# C8 attributes.adk.{route, render_ui_widgets, rewind_before_invocation_id} +# D1 on_state_change_callback removed +# ----------------------------------------------------------------------------- + + +def test_derive_scope_unscoped(): + """C3: None isolation_scope → scope = null.""" + assert bigquery_agent_analytics_plugin._derive_scope(None) is None + + +def test_derive_scope_node_run_bare(): + """C3: bare 'name@run_id' classifies as node_run (not function_call).""" + scope = bigquery_agent_analytics_plugin._derive_scope("loopA@42") + assert scope == {"id": "loopA@42", "kind": "node_run"} + + +def test_derive_scope_node_run_path(): + """C3: 'parent/name@run_id' classifies as node_run.""" + scope = bigquery_agent_analytics_plugin._derive_scope("wf/A@1/B@2") + assert scope == {"id": "wf/A@1/B@2", "kind": "node_run"} + + +def test_derive_scope_function_call_provider_id(): + """C3: model-provided FC IDs (call_*, toolu_*) classify as function_call.""" + for fc_id in ("call_abc123", "toolu_xyz", "adk-fc-1"): + scope = bigquery_agent_analytics_plugin._derive_scope(fc_id) + assert scope == {"id": fc_id, "kind": "function_call"}, fc_id + + +def test_derive_scope_empty_string_unknown(): + """C3: empty/non-string anomalies classify as unknown.""" + scope = bigquery_agent_analytics_plugin._derive_scope("") + assert scope == {"id": "", "kind": "unknown"} + + +def test_d1_on_state_change_callback_removed(): + """D1: the deprecated stub is gone from the public surface.""" + assert not hasattr( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin, + "on_state_change_callback", + ) + + +class TestAdkEnvelope: + """A1 / A2 / A3 / C1 / C2 / C3 / C8 envelope shape on emitted rows.""" + + @pytest.mark.asyncio + async def test_envelope_on_non_event_row( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """USER_MESSAGE_RECEIVED has no source Event → A1/A2 only, A3/C1/C2/C3 null.""" + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content(role="user", parts=[types.Part(text="hi")]), + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + _assert_common_fields(log_entry, "USER_MESSAGE_RECEIVED") + attributes = json.loads(log_entry["attributes"]) + adk = attributes["adk"] + # A1: schema_version always present. + assert adk["schema_version"] == ( + bigquery_agent_analytics_plugin._ADK_ENVELOPE_SCHEMA_VERSION + ) + # A2: app_name always present (from session). + assert adk["app_name"] == "test_app" + # A3 / C1 / C2 / C3 absent on rows without an originating Event. + assert "source_event_id" not in adk + assert "node" not in adk + assert "branch" not in adk + assert "scope" not in adk + + @pytest.mark.asyncio + async def test_envelope_on_event_row( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """STATE_DELTA from on_event_callback carries the full envelope.""" + state_delta = {"k": "v"} + event = event_lib.Event( + author="agent_a", + branch="branch-x", + actions=event_actions_lib.EventActions(state_delta=state_delta), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + _assert_common_fields(log_entry, "STATE_DELTA") + attributes = json.loads(log_entry["attributes"]) + adk = attributes["adk"] + assert adk["schema_version"] == ( + bigquery_agent_analytics_plugin._ADK_ENVELOPE_SCHEMA_VERSION + ) + assert adk["app_name"] == "test_app" + # A3: real Event.id (model_post_init auto-assigns a UUID). + assert adk["source_event_id"] == event.id + assert len(event.id) == 36 # sanity + # C2: branch passthrough. + assert adk["branch"] == "branch-x" + # C1: node defaults to path="" with parent_path=null (no synthesis). + assert adk["node"]["path"] == "" + assert adk["node"]["parent_path"] is None + + @pytest.mark.asyncio + async def test_envelope_node_with_parent_path( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """C1: parent_path is everything before the final '/segment@run_id'.""" + event = event_lib.Event( + author="agent_b", + actions=event_actions_lib.EventActions(state_delta={"k": "v"}), + ) + event.node_info.path = "wf/A@1/B@2" + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + adk = json.loads(log_entry["attributes"])["adk"] + assert adk["node"]["path"] == "wf/A@1/B@2" + assert adk["node"]["parent_path"] == "wf/A@1" + + +class TestC4AgentTransfer: + + @pytest.mark.asyncio + async def test_agent_transfer_emits_from_to_payload( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = event_lib.Event( + author="root_agent", + actions=event_actions_lib.EventActions( + transfer_to_agent="specialist_agent" + ), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + transfers = [r for r in rows if r["event_type"] == "AGENT_TRANSFER"] + assert len(transfers) == 1 + content = json.loads(transfers[0]["content"]) + assert content == { + "from_agent": "root_agent", + "to_agent": "specialist_agent", + } + + +class TestC5EventCompaction: + + @pytest.mark.asyncio + async def test_event_compaction_preserves_float_precision( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """C5: fractional float-epoch seconds must survive the producer.""" + compaction = event_actions_lib.EventCompaction( + start_timestamp=1700000000.125, + end_timestamp=1700000003.875, + compacted_content=types.Content( + role="model", parts=[types.Part(text="summary")] + ), + ) + event = event_lib.Event( + author="agent", + actions=event_actions_lib.EventActions(compaction=compaction), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + compactions = [r for r in rows if r["event_type"] == "EVENT_COMPACTION"] + assert len(compactions) == 1 + content = json.loads(compactions[0]["content"]) + assert content["start_timestamp"] == 1700000000.125 + assert content["end_timestamp"] == 1700000003.875 + assert content["start_timestamp"] != int(content["start_timestamp"]) + + +class TestC6AgentStateCheckpoint: + + @pytest.mark.asyncio + async def test_checkpoint_state_only( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """{agent_state: {...}, end_of_agent: None} emits a CHECKPOINT row.""" + event = event_lib.Event( + author="agent", + actions=event_actions_lib.EventActions( + agent_state={"step": 3, "ctx": "abc"} + ), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + cps = [r for r in rows if r["event_type"] == "AGENT_STATE_CHECKPOINT"] + assert len(cps) == 1 + content = json.loads(cps[0]["content"]) + assert content["agent_state"] == {"step": 3, "ctx": "abc"} + assert content["end_of_agent"] is False + + @pytest.mark.asyncio + async def test_checkpoint_end_of_agent_only( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """{agent_state: None, end_of_agent: True} is a valid CHECKPOINT shape.""" + event = event_lib.Event( + author="agent", + actions=event_actions_lib.EventActions(end_of_agent=True), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + cps = [r for r in rows if r["event_type"] == "AGENT_STATE_CHECKPOINT"] + assert len(cps) == 1 + content = json.loads(cps[0]["content"]) + assert content["agent_state"] is None + assert content["end_of_agent"] is True + + @pytest.mark.asyncio + async def test_checkpoint_carries_real_source_event_id( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """v3 regression guard: Event.model_post_init auto-assigns id, so a + checkpoint Event constructed without explicit id still surfaces a real + 36-char UUID in attributes.adk.source_event_id.""" + event = event_lib.Event( + author="agent", + actions=event_actions_lib.EventActions(end_of_agent=True), + ) + assert event.id and len(event.id) == 36 + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + cps = [r for r in rows if r["event_type"] == "AGENT_STATE_CHECKPOINT"] + assert len(cps) == 1 + adk = json.loads(cps[0]["attributes"])["adk"] + assert adk["source_event_id"] == event.id + + +class TestC7ToolPauseAndComplete: + + @pytest.mark.asyncio + async def test_tool_paused_non_hitl_pause_kind_tool( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + fc = types.FunctionCall( + id="call-1", name="long_running_search", args={"q": "x"} + ) + event = event_lib.Event( + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=fc)] + ), + long_running_tool_ids={"call-1"}, + actions=event_actions_lib.EventActions(), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + pauses = [r for r in rows if r["event_type"] == "TOOL_PAUSED"] + assert len(pauses) == 1 + # C7 pair keys live UNDER ``attributes.adk`` so the consumer SQL on + # ``JSON_VALUE(attributes, '$.adk.function_call_id')`` resolves. + adk = json.loads(pauses[0]["attributes"])["adk"] + assert adk["pause_kind"] == "tool" + assert adk["function_call_id"] == "call-1" + + @pytest.mark.asyncio + async def test_tool_paused_hitl_pause_kind( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """C7: HITL long-running call → pause_kind derived from NAME, not id.""" + fc = types.FunctionCall( + id="call-hitl-1", name="adk_request_confirmation", args={} + ) + event = event_lib.Event( + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=fc)] + ), + long_running_tool_ids={"call-hitl-1"}, + actions=event_actions_lib.EventActions(), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + pauses = [r for r in rows if r["event_type"] == "TOOL_PAUSED"] + assert len(pauses) == 1 + adk = json.loads(pauses[0]["attributes"])["adk"] + assert adk["pause_kind"] == "hitl_confirmation" + assert adk["function_call_id"] == "call-hitl-1" + + @pytest.mark.asyncio + async def test_user_message_function_response_emits_tool_completed( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """C7: non-HITL function_response in a user message → TOOL_COMPLETED + with pause_kind='tool' (this is the long-running resume path).""" + fr = types.FunctionResponse( + id="call-1", name="long_running_search", response={"hits": 7} + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content( + role="user", parts=[types.Part(function_response=fr)] + ), + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + completed = [r for r in rows if r["event_type"] == "TOOL_COMPLETED"] + assert len(completed) == 1 + adk = json.loads(completed[0]["attributes"])["adk"] + assert adk["pause_kind"] == "tool" + assert adk["function_call_id"] == "call-1" + + @pytest.mark.asyncio + async def test_hitl_user_message_does_not_emit_tool_completed( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """C7 HITL non-routing: an adk_request_confirmation function_response in + a user message emits ONLY HITL_CONFIRMATION_REQUEST_COMPLETED, never + TOOL_COMPLETED.""" + fr = types.FunctionResponse( + id="call-hitl-1", + name="adk_request_confirmation", + response={"approved": True}, + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content( + role="user", parts=[types.Part(function_response=fr)] + ), + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + types_emitted = {r["event_type"] for r in rows} + assert "HITL_CONFIRMATION_REQUEST_COMPLETED" in types_emitted + assert "TOOL_COMPLETED" not in types_emitted + + +class TestC8ActionAttributes: + + @pytest.mark.asyncio + async def test_route_and_rewind_flat_under_attributes_adk( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """C8: route / rewind_before_invocation_id mirror under + attributes.adk.* (flat-with-prefix, NOT nested under .actions.).""" + event = event_lib.Event( + author="agent", + actions=event_actions_lib.EventActions( + state_delta={"k": "v"}, # to ensure an emit happens + route="branch_b", + rewind_before_invocation_id="inv-earlier", + ), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + adk = json.loads(log_entry["attributes"])["adk"] + # Flat-with-prefix mirror under attributes.adk.*. + assert adk["route"] == "branch_b" + assert adk["rewind_before_invocation_id"] == "inv-earlier" + # Not nested under .actions. + assert "actions" not in adk + + +class TestViewDefsRegistration: + """The plugin's own per-event-type view defs cover the new types.""" + + def test_new_event_types_registered_in_view_defs(self): + defs = bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS + for event_type in ( + "AGENT_TRANSFER", + "EVENT_COMPACTION", + "AGENT_STATE_CHECKPOINT", + "TOOL_PAUSED", + ): + assert event_type in defs, f"{event_type} missing from _EVENT_VIEW_DEFS" + assert isinstance(defs[event_type], list) + + def test_tool_paused_view_extracts_pair_keys(self): + cols = "\n".join( + bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS["TOOL_PAUSED"] + ) + assert "$.adk.pause_kind" in cols + assert "$.adk.function_call_id" in cols + + def test_compaction_view_preserves_float_and_widens(self): + cols = "\n".join( + bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS["EVENT_COMPACTION"] + ) + # Float passthrough for diagnostics + TIMESTAMP_MICROS widening + # (TIMESTAMP_SECONDS would truncate fractional windows). + assert "AS FLOAT64) AS start_seconds" in cols + assert "TIMESTAMP_MICROS" in cols + assert "TIMESTAMP_SECONDS" not in cols + + def test_tool_completed_view_exposes_pair_keys(self): + """v_tool_completed can do the pause/completion join end-to-end.""" + cols = "\n".join( + bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS["TOOL_COMPLETED"] + ) + assert "$.adk.pause_kind" in cols + assert "$.adk.function_call_id" in cols + + def test_checkpoint_view_exposes_agent_state_type(self): + """v_agent_state_checkpoint discriminates explicit JSON null from + object-valued agent_state via JSON_TYPE(JSON_QUERY(...)).""" + cols = "\n".join( + bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS[ + "AGENT_STATE_CHECKPOINT" + ] + ) + assert "JSON_TYPE(JSON_QUERY(content," in cols + assert "AS agent_state_type" in cols + + +class TestUnmatchedLongRunningIdFallback: + + @pytest.mark.asyncio + async def test_unmatched_long_running_id_emits_tool_paused( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + caplog, + ): + """A long_running_tool_id with no matching function_call part still + emits a pairable TOOL_PAUSED row with pause_kind='tool' + warning.""" + event = event_lib.Event( + author="agent", + content=types.Content( + role="model", parts=[types.Part(text="thinking...")] + ), + long_running_tool_ids={"orphan-pause-1"}, + actions=event_actions_lib.EventActions(), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + with caplog.at_level("WARNING"): + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + pauses = [r for r in rows if r["event_type"] == "TOOL_PAUSED"] + assert len(pauses) == 1 + adk = json.loads(pauses[0]["attributes"])["adk"] + assert adk["pause_kind"] == "tool" + assert adk["function_call_id"] == "orphan-pause-1" + assert any( + "no matching function_call part" in rec.message + for rec in caplog.records + ) + + @pytest.mark.asyncio + async def test_matched_id_not_double_emitted_by_fallback( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """An id with a matching part emits exactly one TOOL_PAUSED row.""" + fc = types.FunctionCall(id="call-1", name="long_search", args={}) + event = event_lib.Event( + author="agent", + content=types.Content( + role="model", parts=[types.Part(function_call=fc)] + ), + long_running_tool_ids={"call-1"}, + actions=event_actions_lib.EventActions(), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + pauses = [r for r in rows if r["event_type"] == "TOOL_PAUSED"] + assert len(pauses) == 1 + + +class TestAgentlessInvocationContext: + """Workflow-driven invocations can have InvocationContext.agent=None; + rows must not be dropped (ReadonlyContext.agent_name raises + AttributeError in that case).""" + + @pytest.fixture + def agentless_invocation_context(self, mock_session): + mock_session_service = mock.create_autospec( + base_session_service_lib.BaseSessionService, + instance=True, + spec_set=True, + ) + mock_plugin_manager = mock.create_autospec( + plugin_manager_lib.PluginManager, instance=True, spec_set=True + ) + return InvocationContext( + agent=None, + session=mock_session, + invocation_id="inv-workflow-1", + session_service=mock_session_service, + plugin_manager=mock_plugin_manager, + ) + + @pytest.mark.asyncio + async def test_event_row_uses_source_event_author( + self, + bq_plugin_inst, + mock_write_client, + agentless_invocation_context, + dummy_arrow_schema, + ): + """Event-originating row falls back to event.author for the agent + column instead of being dropped.""" + event = event_lib.Event( + author="workflow_node_b", + actions=event_actions_lib.EventActions(state_delta={"k": "v"}), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span( + agentless_invocation_context + ) + await bq_plugin_inst.on_event_callback( + invocation_context=agentless_invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + assert len(rows) == 1 + assert rows[0]["event_type"] == "STATE_DELTA" + assert rows[0]["agent"] == "workflow_node_b" + + @pytest.mark.asyncio + async def test_workflow_event_types_not_dropped( + self, + bq_plugin_inst, + mock_write_client, + agentless_invocation_context, + dummy_arrow_schema, + ): + """The ADK 2.0 workflow-centric event types (the ones most likely + to fire with agent=None) land instead of being dropped.""" + event = event_lib.Event( + author="supervisor_node", + actions=event_actions_lib.EventActions( + transfer_to_agent="specialist", + agent_state={"step": 1}, + ), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span( + agentless_invocation_context + ) + await bq_plugin_inst.on_event_callback( + invocation_context=agentless_invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + types_emitted = {r["event_type"] for r in rows} + assert "AGENT_TRANSFER" in types_emitted + assert "AGENT_STATE_CHECKPOINT" in types_emitted + assert all(r["agent"] == "supervisor_node" for r in rows) + + @pytest.mark.asyncio + async def test_callback_only_row_gets_null_agent( + self, + bq_plugin_inst, + mock_write_client, + agentless_invocation_context, + dummy_arrow_schema, + ): + """Row with no source Event (user-message path) gets agent=null + rather than being dropped.""" + bigquery_agent_analytics_plugin.TraceManager.push_span( + agentless_invocation_context + ) + await bq_plugin_inst.on_user_message_callback( + invocation_context=agentless_invocation_context, + user_message=types.Content(role="user", parts=[types.Part(text="hi")]), + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + assert len(rows) == 1 + assert rows[0]["event_type"] == "USER_MESSAGE_RECEIVED" + assert rows[0]["agent"] is None + + @pytest.mark.asyncio + async def test_event_author_wins_even_if_core_returns_sentinel( + self, + bq_plugin_inst, + mock_write_client, + agentless_invocation_context, + dummy_arrow_schema, + monkeypatch, + ): + """Simulates the core sentinel fix for the no-agent case + (ReadonlyContext.agent_name returning 'unknown' instead of + raising): the row must still use Event.author because the plugin + derives the label from the underlying invocation context, not from + ReadonlyContext.agent_name.""" + from google.adk.agents import readonly_context as readonly_context_lib + + monkeypatch.setattr( + readonly_context_lib.ReadonlyContext, + "agent_name", + property(lambda self: "unknown"), + ) + event = event_lib.Event( + author="workflow_node_c", + actions=event_actions_lib.EventActions(state_delta={"k": "v"}), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span( + agentless_invocation_context + ) + await bq_plugin_inst.on_event_callback( + invocation_context=agentless_invocation_context, event=event + ) + await asyncio.sleep(0.01) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + assert len(rows) == 1 + assert rows[0]["agent"] == "workflow_node_c" + assert rows[0]["agent"] != "unknown"