diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 36d92bf781d..ef345367fad 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -3241,6 +3241,53 @@ async def _log_event( # --- UPDATED CALLBACKS FOR V1 PARITY --- + def _resolve_same_session_pause_orphan( + self, + invocation_context: "InvocationContext", + function_call_id: Optional[str], + ) -> Optional[bool]: + """Same-session resolution of ``pause_orphan`` for a long-running resume. + + Scans the in-memory session history for the originating paused + function call. This is a hot-path-safe, zero-I/O check: it never + issues a session-service or BigQuery read, so it adds no latency to + the resume path. + + Returns: + * ``False`` — the originating paused call is present in this + session's in-memory history, so this completion is definitely + not an orphan. + * ``None`` — unknown / not yet settled: there is no id to pair on, + no session history is available, or the pause was not found in + the (possibly trimmed) in-memory history. Callers omit the + ``pause_orphan`` attribute in this case so consumers read it as + SQL NULL ("not yet determined"). + + Never returns ``True``. Declaring a true orphan requires the settled + fallback (an off-hot-path session-service / BigQuery reconciliation), + which is intentionally out of scope here so the resume path stays + free of added reads. A persistent ``None`` is what that future + fallback upgrades to ``True``. + """ + if not function_call_id: + return None + session = getattr(invocation_context, "session", None) + events = getattr(session, "events", None) if session is not None else None + if not events: + return None + for event in events: + ids = getattr(event, "long_running_tool_ids", None) + if not ids or function_call_id not in ids: + continue + content = getattr(event, "content", None) + parts = getattr(content, "parts", None) if content is not None else None + for part in parts or (): + function_call = getattr(part, "function_call", None) + if function_call is not None and function_call.id == function_call_id: + # The originating paused call is in this session's history. + return False + return None + @_safe_callback async def on_user_message_callback( self, @@ -3300,8 +3347,6 @@ async def on_user_message_callback( # tool calls complete inside the agent run via # after_tool_callback, so a function_response inside a user # message is the resume side of a previously-paused tool. - # Stamp the pair keys; pause_orphan / registry semantics - # are intentionally deferred. if not part.function_response.id: logger.debug( "User-message function_response for tool %s has no id;" @@ -3309,17 +3354,26 @@ async def on_user_message_callback( " TOOL_PAUSED row.", part.function_response.name, ) + adk_extras = { + "pause_kind": "tool", + "function_call_id": part.function_response.id, + } + # pause_orphan: stamped False only when this session's history + # proves the completion pairs with a real pause. Omitted (-> + # SQL NULL = "unknown / not yet settled") otherwise. True is + # reserved for the off-hot-path settled fallback so the resume + # path stays free of session-service / BigQuery reads. + pause_orphan = self._resolve_same_session_pause_orphan( + invocation_context, part.function_response.id + ) + if pause_orphan is not None: + adk_extras["pause_orphan"] = pause_orphan await self._log_event( "TOOL_COMPLETED", callback_ctx, raw_content=content_dict, is_truncated=is_truncated, - event_data=EventData( - adk_extras={ - "pause_kind": "tool", - "function_call_id": part.function_response.id, - }, - ), + event_data=EventData(adk_extras=adk_extras), ) @_safe_callback diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 7cc8e3600c2..83ac094afc5 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -8815,3 +8815,149 @@ async def test_matched_id_not_double_emitted_by_fallback( rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) pauses = [r for r in rows if r["event_type"] == "TOOL_PAUSED"] assert len(pauses) == 1 + + +class TestPauseOrphanSameSession: + """#206 PR 1 — hot-path same-session pause resolution. + + ``_resolve_same_session_pause_orphan`` returns False only when the + originating pause is present in the in-memory session history, None + otherwise (unknown / not yet settled), and never True. The emit path + stamps ``pause_orphan`` only when the resolver returns a real bool. + """ + + def _plugin(self): + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID + ) + + def _pause_event(self, fc_id): + fc = types.FunctionCall(id=fc_id, name="long_search", args={}) + return event_lib.Event( + author="agent", + content=types.Content(role="model", parts=[types.Part(function_call=fc)]), + long_running_tool_ids={fc_id}, + actions=event_actions_lib.EventActions(), + ) + + def _ic(self, events): + return types.SimpleNamespace( + session=types.SimpleNamespace(events=events) + ) + + # -- resolver: the only path that returns False -- + def test_found_in_session_history_is_not_orphan(self): + plugin = self._plugin() + ic = self._ic([self._pause_event("call-1")]) + assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is False + + # -- everything else is None ("unknown"), never True -- + def test_not_found_is_unknown(self): + plugin = self._plugin() + ic = self._ic([self._pause_event("other-call")]) + assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is None + + def test_empty_history_is_unknown(self): + plugin = self._plugin() + assert plugin._resolve_same_session_pause_orphan(self._ic([]), "call-1") is None + + def test_missing_function_call_id_is_unknown(self): + plugin = self._plugin() + ic = self._ic([self._pause_event("call-1")]) + assert plugin._resolve_same_session_pause_orphan(ic, None) is None + assert plugin._resolve_same_session_pause_orphan(ic, "") is None + + def test_id_in_long_running_but_no_matching_part_is_unknown(self): + # long_running_tool_ids carries the id but no function_call part + # actually pairs it — cannot prove the pause, so unknown. + plugin = self._plugin() + ev = event_lib.Event( + author="agent", + content=types.Content(role="model", parts=[types.Part(text="hi")]), + long_running_tool_ids={"call-1"}, + actions=event_actions_lib.EventActions(), + ) + assert plugin._resolve_same_session_pause_orphan(self._ic([ev]), "call-1") is None + + def test_no_session_is_unknown(self): + plugin = self._plugin() + ic = types.SimpleNamespace(session=None) + assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is None + + def test_never_returns_true(self): + # The resolver must never declare a true orphan in PR 1 — that is + # reserved for the off-hot-path settled fallback. + plugin = self._plugin() + for ic in (self._ic([]), self._ic([self._pause_event("x")]), + types.SimpleNamespace(session=None)): + assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is not True + + def test_resolver_does_no_session_service_read(self): + # Hot-path guarantee: the resolver must not call session_service / + # get_session. A session_service whose every attr raises proves the + # resolver only touches the in-memory session.events. + plugin = self._plugin() + + class _Exploding: + def __getattr__(self, name): + raise AssertionError(f"resolver touched session_service.{name}") + + ic = types.SimpleNamespace( + session=types.SimpleNamespace(events=[self._pause_event("call-1")]), + session_service=_Exploding(), + ) + assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is False + + # -- emit path: pause_orphan lands on the TOOL_COMPLETED row -- + @pytest.mark.asyncio + async def test_emit_stamps_false_when_pause_in_history( + self, bq_plugin_inst, mock_write_client, invocation_context, + dummy_arrow_schema, + ): + type(invocation_context.session).events = mock.PropertyMock( + return_value=[self._pause_event("call-1")] + ) + user_message = types.Content( + role="user", + parts=[types.Part(function_response=types.FunctionResponse( + id="call-1", name="long_search", response={"ok": True}))], + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, user_message=user_message + ) + await bq_plugin_inst.flush() + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + completed = [r for r in rows if r["event_type"] == "TOOL_COMPLETED"] + assert len(completed) == 1 + adk = json.loads(completed[0]["attributes"])["adk"] + assert adk["function_call_id"] == "call-1" + assert adk["pause_kind"] == "tool" + assert adk["pause_orphan"] is False + + @pytest.mark.asyncio + async def test_emit_omits_pause_orphan_when_unknown( + self, bq_plugin_inst, mock_write_client, invocation_context, + dummy_arrow_schema, + ): + # No matching pause in history -> pause_orphan omitted (SQL NULL), + # never stamped true. + type(invocation_context.session).events = mock.PropertyMock( + return_value=[] + ) + user_message = types.Content( + role="user", + parts=[types.Part(function_response=types.FunctionResponse( + id="call-1", name="long_search", response={"ok": True}))], + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, user_message=user_message + ) + await bq_plugin_inst.flush() + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + completed = [r for r in rows if r["event_type"] == "TOOL_COMPLETED"] + assert len(completed) == 1 + adk = json.loads(completed[0]["attributes"])["adk"] + assert adk["function_call_id"] == "call-1" + assert "pause_orphan" not in adk