Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 83 additions & 91 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,20 +634,35 @@ class BigQueryLoggerConfig:

@dataclass
class _SpanRecord:
"""A single record on the unified span stack.

Consolidates span, id, ownership, and timing into one object
so all stacks stay in sync by construction.

Note: The plugin intentionally does NOT attach its spans to the
ambient OTel context (no ``context.attach``). This prevents the
plugin from corrupting the framework's span hierarchy when an
external OTel exporter (e.g. ``opentelemetry-instrumentation-vertexai``)
is active. See https://github.com/google/adk-python/issues/4561.
"""A single record on the BQAA plugin's internal span stack.

Stores the IDs and timing the plugin needs to populate BigQuery
``span_id`` / ``parent_span_id`` / ``trace_id`` / ``latency_ms``
columns. Crucially, no OpenTelemetry ``Span`` object is held.

Background — prior approach and the bug it caused:
The previous implementation created real OTel spans via
``tracer.start_span(...)`` purely as ID carriers. When the host
application has an OTel exporter configured (notably Agent Engine
with ``GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY=true``), those
plugin-owned spans were exported to Cloud Trace alongside the
framework's real spans — producing a duplicate-span view for
every BQAA-instrumented operation. See haiyuan-eng-google/BQAA-SDK#94.

The plugin already tracked all parent / child relationships on
this internal stack, so the OTel span object was incidental to
correctness. We now store ``trace_id`` directly on each record
(inherited from the ambient OTel span when present, generated
otherwise) and skip span creation entirely. Cross-system
correlation with Cloud Trace still works via ``trace_id``
inheritance.

``attach_current_span`` (which observes the ambient span without
owning one) is unaffected by this change.
"""

span: trace.Span
span_id: str
trace_id: str
owns_span: bool
start_time_ns: int
first_token_time: Optional[float] = None
Expand Down Expand Up @@ -689,17 +704,16 @@ def init_trace(callback_context: CallbackContext) -> None:

@staticmethod
def get_trace_id(callback_context: CallbackContext) -> Optional[str]:
  """Gets the trace ID from the internal span stack, OTel, or invocation_id.

  Resolution order:
    1. ``trace_id`` stored on the top record of the plugin's internal
       span stack.
    2. Trace ID of the ambient OTel span when its span context is valid
       (e.g. callbacks fired before any plugin record was pushed),
       formatted as 32 lowercase hex characters.
    3. ``callback_context.invocation_id`` as the last resort.

  Args:
    callback_context: Callback context; only ``invocation_id`` is read.

  Returns:
    The resolved trace ID string, or the invocation ID when neither the
    internal stack nor the ambient OTel context provides one.
  """
  records = _span_records_ctx.get()
  if records:
    # Records carry the trace_id directly; no OTel Span object is held.
    return records[-1].trace_id

  # Fallback to ambient OTel context (e.g. callbacks fired before
  # any plugin span was pushed).
  ambient_ctx = trace.get_current_span().get_span_context()
  if ambient_ctx.is_valid:
    return format(ambient_ctx.trace_id, "032x")

  return callback_context.invocation_id

def push_span(
    callback_context: CallbackContext,
    span_name: Optional[str] = "adk-span",
) -> str:
  """Pushes a plugin-internal span record onto the stack.

  No OpenTelemetry span is created. The record carries everything the
  plugin needs to populate BigQuery columns:

  * ``span_id`` -- newly generated 16-hex string.
  * ``trace_id`` -- inherited by precedence:
      1. Top of the existing internal stack (keeps every push within
         an invocation under one trace_id).
      2. Ambient OTel span when its context is valid -- keeps BigQuery
         rows joinable to external traces via the shared ``trace_id``.
      3. A fresh 32-hex value (no ambient context, e.g. unit tests or
         non-OTel runtimes).
  * ``start_time_ns`` -- wall-clock start, used for latency on pop.

  Args:
    callback_context: Callback context passed to ``init_trace``.
    span_name: Unused. Preserved on the signature for API stability;
      no OTel span name is set.

  Returns:
    The generated 16-hex ``span_id`` of the pushed record.
  """
  del span_name  # No-op: kept for API stability; no OTel span is created.
  TraceManager.init_trace(callback_context)

  records = TraceManager._get_records()
  if records:
    # Inherit from the parent record so the whole stack shares one trace.
    trace_id = records[-1].trace_id
  else:
    ambient_ctx = trace.get_current_span().get_span_context()
    if ambient_ctx.is_valid:
      trace_id = format(ambient_ctx.trace_id, "032x")
    else:
      trace_id = uuid.uuid4().hex  # 32 hex chars

  span_id_str = uuid.uuid4().hex[:16]

  record = _SpanRecord(
      span_id=span_id_str,
      trace_id=trace_id,
      owns_span=True,
      start_time_ns=time.time_ns(),
  )

  # Publish a new list rather than mutating the one already stored in
  # the context variable.
  _span_records_ctx.set(list(records) + [record])

  return span_id_str

@staticmethod
def attach_current_span(
    callback_context: CallbackContext,
) -> str:
  """Records the ambient OTel span's IDs on the stack without owning it.

  No OTel span is created or attached. The ambient span's ``trace_id``
  and ``span_id`` are captured so plugin-emitted BigQuery rows correlate
  with whatever tracing the host is already running. When the ambient
  span context is not valid, random IDs are generated instead (16 hex
  chars for ``span_id``, 32 for ``trace_id``).

  Args:
    callback_context: Callback context passed to ``init_trace``.

  Returns:
    The ``span_id`` stored on the pushed record.
  """
  TraceManager.init_trace(callback_context)

  ambient_ctx = trace.get_current_span().get_span_context()
  if ambient_ctx.is_valid:
    span_id_str = format(ambient_ctx.span_id, "016x")
    trace_id = format(ambient_ctx.trace_id, "032x")
  else:
    span_id_str = uuid.uuid4().hex[:16]
    trace_id = uuid.uuid4().hex

  record = _SpanRecord(
      span_id=span_id_str,
      trace_id=trace_id,
      owns_span=False,  # Observed only; the plugin did not create it.
      start_time_ns=time.time_ns(),
  )

  records = TraceManager._get_records()
  _span_records_ctx.set(list(records) + [record])

  return span_id_str

Expand Down Expand Up @@ -828,10 +844,10 @@ def ensure_invocation_span(

@staticmethod
def pop_span() -> tuple[Optional[str], Optional[int]]:
  """Pops the top span record from the internal stack.

  No OTel span is ended because the plugin does not create one.

  Returns:
    ``(span_id, duration_ms)`` for the popped record, where
    ``duration_ms`` is the wall-clock time elapsed since the record's
    ``start_time_ns``. ``(None, None)`` when the stack is empty.
  """
  records = _span_records_ctx.get()
  if not records:
    # NOTE(review): this early return and the list copy below sit in a
    # collapsed region of the reviewed diff; confirm against the file.
    return None, None

  new_records = list(records)
  record = new_records.pop()
  _span_records_ctx.set(new_records)

  duration_ms = int((time.time_ns() - record.start_time_ns) / 1_000_000)
  return record.span_id, duration_ms

@staticmethod
def clear_stack() -> None:
  """Clears all span records. Safety net for cross-invocation cleanup.

  Records hold only IDs and timestamps (no OTel span objects), so there
  is nothing to end or release; the stack is simply reset to empty.
  """
  _span_records_ctx.set([])

@staticmethod
def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]:
Expand Down Expand Up @@ -894,19 +894,11 @@ def get_root_agent_name() -> Optional[str]:

@staticmethod
def get_start_time(span_id: str) -> Optional[float]:
  """Gets the start time of a span by ID.

  Args:
    span_id: The ``span_id`` of the record to look up.

  Returns:
    The record's start time in seconds since the epoch (converted from
    ``start_time_ns``), or ``None`` if no record on the stack matches.
  """
  records = _span_records_ctx.get()
  if records:
    # Search from the top of the stack so the innermost match wins.
    for record in reversed(records):
      if record.span_id == span_id:
        return record.start_time_ns / 1_000_000_000.0
  return None

Expand Down
Loading
Loading