Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3241,6 +3241,53 @@ async def _log_event(

# --- UPDATED CALLBACKS FOR V1 PARITY ---

def _resolve_same_session_pause_orphan(
self,
invocation_context: "InvocationContext",
function_call_id: Optional[str],
) -> Optional[bool]:
"""Same-session resolution of ``pause_orphan`` for a long-running resume.

Scans the in-memory session history for the originating paused
function call. This is a hot-path-safe, zero-I/O check: it never
issues a session-service or BigQuery read, so it adds no latency to
the resume path.

Returns:
* ``False`` — the originating paused call is present in this
session's in-memory history, so this completion is definitely
not an orphan.
* ``None`` — unknown / not yet settled: there is no id to pair on,
no session history is available, or the pause was not found in
the (possibly trimmed) in-memory history. Callers omit the
``pause_orphan`` attribute in this case so consumers read it as
SQL NULL ("not yet determined").

Never returns ``True``. Declaring a true orphan requires the settled
fallback (an off-hot-path session-service / BigQuery reconciliation),
which is intentionally out of scope here so the resume path stays
free of added reads. A persistent ``None`` is what that future
fallback upgrades to ``True``.
"""
if not function_call_id:
return None
session = getattr(invocation_context, "session", None)
events = getattr(session, "events", None) if session is not None else None
if not events:
return None
for event in events:
ids = getattr(event, "long_running_tool_ids", None)
if not ids or function_call_id not in ids:
continue
content = getattr(event, "content", None)
parts = getattr(content, "parts", None) if content is not None else None
for part in parts or ():
function_call = getattr(part, "function_call", None)
if function_call is not None and function_call.id == function_call_id:
# The originating paused call is in this session's history.
return False
return None

@_safe_callback
async def on_user_message_callback(
self,
Expand Down Expand Up @@ -3300,26 +3347,33 @@ async def on_user_message_callback(
# tool calls complete inside the agent run via
# after_tool_callback, so a function_response inside a user
# message is the resume side of a previously-paused tool.
# Stamp the pair keys; pause_orphan / registry semantics
# are intentionally deferred.
if not part.function_response.id:
logger.debug(
"User-message function_response for tool %s has no id;"
" the resulting TOOL_COMPLETED row cannot pair with a"
" TOOL_PAUSED row.",
part.function_response.name,
)
adk_extras = {
"pause_kind": "tool",
"function_call_id": part.function_response.id,
}
# pause_orphan: stamped False only when this session's history
# proves the completion pairs with a real pause. Omitted (->
# SQL NULL = "unknown / not yet settled") otherwise. True is
# reserved for the off-hot-path settled fallback so the resume
# path stays free of session-service / BigQuery reads.
pause_orphan = self._resolve_same_session_pause_orphan(
invocation_context, part.function_response.id
)
if pause_orphan is not None:
adk_extras["pause_orphan"] = pause_orphan
await self._log_event(
"TOOL_COMPLETED",
callback_ctx,
raw_content=content_dict,
is_truncated=is_truncated,
event_data=EventData(
adk_extras={
"pause_kind": "tool",
"function_call_id": part.function_response.id,
},
),
event_data=EventData(adk_extras=adk_extras),
)

@_safe_callback
Expand Down
146 changes: 146 additions & 0 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8815,3 +8815,149 @@ async def test_matched_id_not_double_emitted_by_fallback(
rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
pauses = [r for r in rows if r["event_type"] == "TOOL_PAUSED"]
assert len(pauses) == 1


class TestPauseOrphanSameSession:
"""#206 PR 1 — hot-path same-session pause resolution.

``_resolve_same_session_pause_orphan`` returns False only when the
originating pause is present in the in-memory session history, None
otherwise (unknown / not yet settled), and never True. The emit path
stamps ``pause_orphan`` only when the resolver returns a real bool.
"""

def _plugin(self):
return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
project_id=PROJECT_ID, dataset_id=DATASET_ID, table_id=TABLE_ID
)

def _pause_event(self, fc_id):
fc = types.FunctionCall(id=fc_id, name="long_search", args={})
return event_lib.Event(
author="agent",
content=types.Content(role="model", parts=[types.Part(function_call=fc)]),
long_running_tool_ids={fc_id},
actions=event_actions_lib.EventActions(),
)

def _ic(self, events):
return types.SimpleNamespace(
session=types.SimpleNamespace(events=events)
)

# -- resolver: the only path that returns False --
def test_found_in_session_history_is_not_orphan(self):
plugin = self._plugin()
ic = self._ic([self._pause_event("call-1")])
assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is False

# -- everything else is None ("unknown"), never True --
def test_not_found_is_unknown(self):
plugin = self._plugin()
ic = self._ic([self._pause_event("other-call")])
assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is None

def test_empty_history_is_unknown(self):
plugin = self._plugin()
assert plugin._resolve_same_session_pause_orphan(self._ic([]), "call-1") is None

def test_missing_function_call_id_is_unknown(self):
plugin = self._plugin()
ic = self._ic([self._pause_event("call-1")])
assert plugin._resolve_same_session_pause_orphan(ic, None) is None
assert plugin._resolve_same_session_pause_orphan(ic, "") is None

def test_id_in_long_running_but_no_matching_part_is_unknown(self):
# long_running_tool_ids carries the id but no function_call part
# actually pairs it — cannot prove the pause, so unknown.
plugin = self._plugin()
ev = event_lib.Event(
author="agent",
content=types.Content(role="model", parts=[types.Part(text="hi")]),
long_running_tool_ids={"call-1"},
actions=event_actions_lib.EventActions(),
)
assert plugin._resolve_same_session_pause_orphan(self._ic([ev]), "call-1") is None

def test_no_session_is_unknown(self):
plugin = self._plugin()
ic = types.SimpleNamespace(session=None)
assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is None

def test_never_returns_true(self):
# The resolver must never declare a true orphan in PR 1 — that is
# reserved for the off-hot-path settled fallback.
plugin = self._plugin()
for ic in (self._ic([]), self._ic([self._pause_event("x")]),
types.SimpleNamespace(session=None)):
assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is not True

def test_resolver_does_no_session_service_read(self):
# Hot-path guarantee: the resolver must not call session_service /
# get_session. A session_service whose every attr raises proves the
# resolver only touches the in-memory session.events.
plugin = self._plugin()

class _Exploding:
def __getattr__(self, name):
raise AssertionError(f"resolver touched session_service.{name}")

ic = types.SimpleNamespace(
session=types.SimpleNamespace(events=[self._pause_event("call-1")]),
session_service=_Exploding(),
)
assert plugin._resolve_same_session_pause_orphan(ic, "call-1") is False

# -- emit path: pause_orphan lands on the TOOL_COMPLETED row --
@pytest.mark.asyncio
async def test_emit_stamps_false_when_pause_in_history(
self, bq_plugin_inst, mock_write_client, invocation_context,
dummy_arrow_schema,
):
type(invocation_context.session).events = mock.PropertyMock(
return_value=[self._pause_event("call-1")]
)
user_message = types.Content(
role="user",
parts=[types.Part(function_response=types.FunctionResponse(
id="call-1", name="long_search", response={"ok": True}))],
)
bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context)
await bq_plugin_inst.on_user_message_callback(
invocation_context=invocation_context, user_message=user_message
)
await bq_plugin_inst.flush()
rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
completed = [r for r in rows if r["event_type"] == "TOOL_COMPLETED"]
assert len(completed) == 1
adk = json.loads(completed[0]["attributes"])["adk"]
assert adk["function_call_id"] == "call-1"
assert adk["pause_kind"] == "tool"
assert adk["pause_orphan"] is False

@pytest.mark.asyncio
async def test_emit_omits_pause_orphan_when_unknown(
self, bq_plugin_inst, mock_write_client, invocation_context,
dummy_arrow_schema,
):
# No matching pause in history -> pause_orphan omitted (SQL NULL),
# never stamped true.
type(invocation_context.session).events = mock.PropertyMock(
return_value=[]
)
user_message = types.Content(
role="user",
parts=[types.Part(function_response=types.FunctionResponse(
id="call-1", name="long_search", response={"ok": True}))],
)
bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context)
await bq_plugin_inst.on_user_message_callback(
invocation_context=invocation_context, user_message=user_message
)
await bq_plugin_inst.flush()
rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
completed = [r for r in rows if r["event_type"] == "TOOL_COMPLETED"]
assert len(completed) == 1
adk = json.loads(completed[0]["attributes"])["adk"]
assert adk["function_call_id"] == "call-1"
assert "pause_orphan" not in adk