From 79f38df9275e307897971d062bd0ed1cc5c44519 Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Tue, 17 Mar 2026 07:55:37 -0400 Subject: [PATCH 1/8] wip: wrappers for async message wrappers. --- .../anthropic/messages_extractors.py | 86 +++ .../instrumentation/anthropic/wrappers.py | 352 ++++++++---- .../tests/test_async_wrappers.py | 536 ++++++++++++++++++ 3 files changed, 879 insertions(+), 95 deletions(-) create mode 100644 instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_async_wrappers.py diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py index 907aa08d39..a69f9051a5 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py @@ -29,6 +29,7 @@ ) from opentelemetry.util.genai.types import ( InputMessage, + LLMInvocation, MessagePart, OutputMessage, ) @@ -37,6 +38,7 @@ from .utils import ( convert_content_to_parts, normalize_finish_reason, + stream_block_state_to_part, ) if TYPE_CHECKING: @@ -55,6 +57,8 @@ Usage, ) + from .utils import StreamBlockState + @dataclass class MessageRequestParams: @@ -153,6 +157,88 @@ def get_output_messages_from_message( ] +def set_invocation_message_response_attributes( + invocation: LLMInvocation, + message: Message | None, + capture_content: bool, +) -> None: + if message is None: + return + + if message.model: + invocation.response_model_name = message.model + + if message.id: + invocation.response_id = message.id + + finish_reason = normalize_finish_reason(message.stop_reason) + if finish_reason: + invocation.finish_reasons = [finish_reason] + + if message.usage: + tokens = extract_usage_tokens(message.usage) + invocation.input_tokens = tokens.input_tokens + invocation.output_tokens = tokens.output_tokens + if tokens.cache_creation_input_tokens is not None: + invocation.attributes[ + GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS + ] = tokens.cache_creation_input_tokens + if tokens.cache_read_input_tokens is not None: + invocation.attributes[GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS] = ( + tokens.cache_read_input_tokens + ) + + if capture_content: + invocation.output_messages = get_output_messages_from_message(message) + + +def set_invocation_stream_response_attributes( + invocation: LLMInvocation, + *, + response_model: str | None, + response_id: str | None, + stop_reason: str | None, + input_tokens: int | None, + output_tokens: int | None, + cache_creation_input_tokens: int | None, + cache_read_input_tokens: int | None, + capture_content: bool, + content_blocks: dict[int, "StreamBlockState"], +) -> None: + if response_model: + invocation.response_model_name = response_model + if response_id: + invocation.response_id = response_id + if stop_reason: + invocation.finish_reasons = [stop_reason] + if input_tokens is not None: + invocation.input_tokens = input_tokens + if output_tokens is not None: + invocation.output_tokens = output_tokens + if cache_creation_input_tokens is not None: + invocation.attributes[ + GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS + ] = cache_creation_input_tokens + if cache_read_input_tokens is not None: + invocation.attributes[ + GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS + ] = cache_read_input_tokens + + if capture_content and content_blocks: + parts: list[MessagePart] = [] + for index in sorted(content_blocks): + part = stream_block_state_to_part(content_blocks[index]) + if part is not None: + parts.append(part) + invocation.output_messages = [ + OutputMessage( + role="assistant", + parts=parts, + finish_reason=stop_reason or "", + ) + ] + + def extract_params( # pylint: disable=too-many-locals *, max_tokens: int | None = None, diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py index 585d70a9f9..9dbc2bd1cb 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py @@ -15,33 +15,37 @@ from __future__ import annotations import logging +from contextlib import AsyncExitStack, ExitStack, contextmanager from types import TracebackType -from typing import TYPE_CHECKING, Callable, Iterator, Optional +from typing import TYPE_CHECKING, Callable, Generator, Generic, Iterator, Optional, TypeVar from opentelemetry.util.genai.handler import TelemetryHandler from opentelemetry.util.genai.types import ( Error, LLMInvocation, - MessagePart, - OutputMessage, ) from .messages_extractors import ( - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS, - GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS, extract_usage_tokens, - get_output_messages_from_message, + set_invocation_message_response_attributes, + set_invocation_stream_response_attributes, ) from .utils import ( StreamBlockState, create_stream_block_state, normalize_finish_reason, - stream_block_state_to_part, update_stream_block_state, ) if TYPE_CHECKING: - from anthropic._streaming import Stream + from anthropic._streaming import AsyncStream, Stream + from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module + AsyncMessageStream, + AsyncMessageStreamManager, + MessageStream, + MessageStreamEvent, + MessageStreamManager, + ) from anthropic.types import ( Message, MessageDeltaUsage, @@ -51,6 +55,43 @@ _logger = logging.getLogger(__name__) +ResponseT = TypeVar("ResponseT") + + +class _ResponseProxy(Generic[ResponseT]): + def __init__(self, response: ResponseT, finalize: Callable[[], None]): + self._response = response + self._finalize = finalize + + def close(self) -> None: + try: + self._response.close() + finally: + self._finalize() + + def __getattr__(self, name: str): + return getattr(self._response, name) + + +class _AsyncResponseProxy(Generic[ResponseT]): + def __init__(self, response: ResponseT, finalize: Callable[[], None]): + self._response = response + self._finalize = finalize + + def close(self) -> None: + try: + self._response.close() + finally: + self._finalize() + + async def aclose(self) -> None: + try: + await self._response.aclose() + finally: + self._finalize() + + def __getattr__(self, name: str): + return getattr(self._response, name) class MessageWrapper: @@ -62,33 +103,9 @@ def __init__(self, message: Message, capture_content: bool): def extract_into(self, invocation: LLMInvocation) -> None: """Extract response data into the invocation.""" - if self._message.model: - invocation.response_model_name = self._message.model - - if self._message.id: - invocation.response_id = self._message.id - - finish_reason = normalize_finish_reason(self._message.stop_reason) - if finish_reason: - invocation.finish_reasons = [finish_reason] - - if self._message.usage: - tokens = extract_usage_tokens(self._message.usage) - invocation.input_tokens = tokens.input_tokens - invocation.output_tokens = tokens.output_tokens - if tokens.cache_creation_input_tokens is not None: - invocation.attributes[ - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS - ] = tokens.cache_creation_input_tokens - if tokens.cache_read_input_tokens is not None: - invocation.attributes[GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS] = ( - tokens.cache_read_input_tokens - ) - - if self._capture_content: - invocation.output_messages = get_output_messages_from_message( - self._message - ) + set_invocation_message_response_attributes( + invocation, self._message, self._capture_content + ) @property def message(self) -> Message: @@ -106,9 +123,9 @@ def __init__( invocation: LLMInvocation, capture_content: bool, ): - self._stream = stream - self._handler = handler - self._invocation = invocation + self.stream = stream + self.handler = handler + self.invocation = invocation self._response_id: Optional[str] = None self._response_model: Optional[str] = None self._stop_reason: Optional[str] = None @@ -158,11 +175,12 @@ def _process_chunk(self, chunk: RawMessageStreamEvent) -> None: update_stream_block_state(block, chunk.delta) @staticmethod + @contextmanager def _safe_instrumentation( - callback: Callable[[], object], context: str - ) -> None: + context: str, + ) -> Generator[None, None, None]: try: - callback() + yield except Exception: # pylint: disable=broad-exception-caught _logger.debug( "Anthropic MessagesStreamWrapper instrumentation error in %s", @@ -172,82 +190,61 @@ def _safe_instrumentation( def _set_invocation_response_attributes(self) -> None: """Extract accumulated stream state into the invocation.""" - if self._response_model: - self._invocation.response_model_name = self._response_model - if self._response_id: - self._invocation.response_id = self._response_id - if self._stop_reason: - self._invocation.finish_reasons = [self._stop_reason] - if self._input_tokens is not None: - self._invocation.input_tokens = self._input_tokens - if self._output_tokens is not None: - self._invocation.output_tokens = self._output_tokens - if self._cache_creation_input_tokens is not None: - self._invocation.attributes[ - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS - ] = self._cache_creation_input_tokens - if self._cache_read_input_tokens is not None: - self._invocation.attributes[ - GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS - ] = self._cache_read_input_tokens - - if self._capture_content and self._content_blocks: - parts: list[MessagePart] = [] - for index in sorted(self._content_blocks): - part = stream_block_state_to_part(self._content_blocks[index]) - if part is not None: - parts.append(part) - self._invocation.output_messages = [ - OutputMessage( - role="assistant", - parts=parts, - finish_reason=self._stop_reason or "", - ) - ] + set_invocation_stream_response_attributes( + self.invocation, + response_model=self._response_model, + response_id=self._response_id, + stop_reason=self._stop_reason, + input_tokens=self._input_tokens, + output_tokens=self._output_tokens, + cache_creation_input_tokens=self._cache_creation_input_tokens, + cache_read_input_tokens=self._cache_read_input_tokens, + capture_content=self._capture_content, + content_blocks=self._content_blocks, + ) def _stop(self) -> None: if self._finalized: return - self._safe_instrumentation( - self._set_invocation_response_attributes, - "response attribute extraction", - ) - self._safe_instrumentation( - lambda: self._handler.stop_llm(self._invocation), - "stop_llm", - ) + with self._safe_instrumentation("response attribute extraction"): + self._set_invocation_response_attributes() + with self._safe_instrumentation("stop_llm"): + self.handler.stop_llm(self.invocation) self._finalized = True def _fail(self, message: str, error_type: type[BaseException]) -> None: if self._finalized: return - self._safe_instrumentation( - lambda: self._handler.fail_llm( - self._invocation, Error(message=message, type=error_type) - ), - "fail_llm", - ) + with self._safe_instrumentation("fail_llm"): + self.handler.fail_llm( + self.invocation, Error(message=message, type=error_type) + ) self._finalized = True def __iter__(self) -> MessagesStreamWrapper: return self def __getattr__(self, name: str) -> object: - return getattr(self._stream, name) + return getattr(self.stream, name) + + @property + def response(self): + response = getattr(self.stream, "response", None) + if response is None: + return None + return _ResponseProxy(response, self._stop) - def __next__(self) -> RawMessageStreamEvent: + def __next__(self) -> "RawMessageStreamEvent | MessageStreamEvent": try: - chunk = next(self._stream) + chunk = next(self.stream) except StopIteration: self._stop() raise except Exception as exc: self._fail(str(exc), type(exc)) raise - self._safe_instrumentation( - lambda: self._process_chunk(chunk), - "stream chunk processing", - ) + with self._safe_instrumentation("stream chunk processing"): + self._process_chunk(chunk) return chunk def __enter__(self) -> MessagesStreamWrapper: @@ -270,6 +267,171 @@ def __exit__( def close(self) -> None: try: - self._stream.close() + self.stream.close() + finally: + self._stop() + + +class AsyncMessagesStreamWrapper(MessagesStreamWrapper): + """Wrapper for async Anthropic Stream that handles telemetry.""" + + stream: "AsyncStream[RawMessageStreamEvent] | AsyncMessageStream" + + async def __aenter__(self) -> "AsyncMessagesStreamWrapper": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: + try: + if exc_type is not None: + self._fail( + str(exc_val), type(exc_val) if exc_val else Exception + ) + finally: + await self.close() + return False + + async def close(self) -> None: # type: ignore[override] + try: + await self.stream.close() finally: self._stop() + + def __aiter__(self) -> "AsyncMessagesStreamWrapper": + return self + + @property + def response(self): + response = getattr(self.stream, "response", None) + if response is None: + return None + return _AsyncResponseProxy(response, self._stop) + + async def __anext__(self) -> "RawMessageStreamEvent | MessageStreamEvent": + try: + chunk = await self.stream.__anext__() + except StopAsyncIteration: + self._stop() + raise + except Exception as exc: + self._fail(str(exc), type(exc)) + raise + with self._safe_instrumentation("stream chunk processing"): + self._process_chunk(chunk) + return chunk + + +class MessagesStreamManagerWrapper: + """Wrapper for sync Anthropic stream managers.""" + + def __init__( + self, + manager: "MessageStreamManager", + handler: TelemetryHandler, + invocation: LLMInvocation, + capture_content: bool, + ): + self._manager = manager + self._handler = handler + self._invocation = invocation + self._capture_content = capture_content + self._stream_wrapper: MessagesStreamWrapper | None = None + + def __enter__(self) -> MessagesStreamWrapper: + stream = self._manager.__enter__() + self._stream_wrapper = MessagesStreamWrapper( + stream, + self._handler, + self._invocation, + self._capture_content, + ) + return self._stream_wrapper + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: + suppressed = False + stream_wrapper = self._stream_wrapper + self._stream_wrapper = None + with ExitStack() as cleanup: + if stream_wrapper is not None: + + def finalize_stream_wrapper() -> None: + if suppressed: + stream_wrapper.__exit__(None, None, None) + else: + stream_wrapper.__exit__(exc_type, exc_val, exc_tb) + + cleanup.callback(finalize_stream_wrapper) + suppressed = self._manager.__exit__(exc_type, exc_val, exc_tb) + return suppressed + + def __getattr__(self, name: str) -> object: + return getattr(self._manager, name) + + +class AsyncMessagesStreamManagerWrapper: + """Wrapper for AsyncMessageStreamManager that handles telemetry. + + Wraps AsyncMessageStreamManager from the Anthropic SDK: + https://github.com/anthropics/anthropic-sdk-python/blob/05220bc1c1079fe01f5c4babc007ec7a990859d9/src/anthropic/lib/streaming/_messages.py#L294 + """ + + def __init__( + self, + manager: "AsyncMessageStreamManager", + handler: TelemetryHandler, + invocation: LLMInvocation, + capture_content: bool, + ): + self._manager = manager + self._handler = handler + self._invocation = invocation + self._capture_content = capture_content + self._stream_wrapper: AsyncMessagesStreamWrapper | None = None + + async def __aenter__(self) -> AsyncMessagesStreamWrapper: + msg_stream = await self._manager.__aenter__() + self._stream_wrapper = AsyncMessagesStreamWrapper( + msg_stream, + self._handler, + self._invocation, + self._capture_content, + ) + return self._stream_wrapper + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: + suppressed = False + stream_wrapper = self._stream_wrapper + self._stream_wrapper = None + async with AsyncExitStack() as cleanup: + if stream_wrapper is not None: + + async def finalize_stream_wrapper() -> None: + if suppressed: + await stream_wrapper.__aexit__(None, None, None) + else: + await stream_wrapper.__aexit__( + exc_type, exc_val, exc_tb + ) + + cleanup.push_async_callback(finalize_stream_wrapper) + suppressed = await self._manager.__aexit__( + exc_type, exc_val, exc_tb + ) + return suppressed + + def __getattr__(self, name: str) -> object: + return getattr(self._manager, name) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_async_wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_async_wrappers.py new file mode 100644 index 0000000000..b2055a525f --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_async_wrappers.py @@ -0,0 +1,536 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +import pytest + +from opentelemetry.instrumentation.anthropic.wrappers import ( + AsyncMessagesStreamManagerWrapper, + AsyncMessagesStreamWrapper, + MessagesStreamManagerWrapper, + MessagesStreamWrapper, +) + + +def _noop_stop_llm(invocation): + del invocation + + +def _noop_fail_llm(invocation, error): + del invocation + del error + + +def _make_handler(): + return SimpleNamespace( + stop_llm=_noop_stop_llm, + fail_llm=_noop_fail_llm, + ) + + +def _make_invocation(): + return SimpleNamespace(attributes={}, request_model=None) + + +def _make_stream_wrapper(stream, handler=None): + return MessagesStreamWrapper( + stream=stream, + handler=handler or _make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + +def _make_async_stream_wrapper(stream, handler=None): + return AsyncMessagesStreamWrapper( + stream=stream, + handler=handler or _make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + +class _FakeSyncStream: + def __init__(self, *, events=None, error=None): + self._events = list(events or []) + self._error = error + self.close_calls = 0 + self.response = _FakeSyncResponse() + + def __iter__(self): + return self + + def __next__(self): + if self._events: + return self._events.pop(0) + if self._error is not None: + raise self._error + raise StopIteration + + def close(self): + self.close_calls += 1 + + +class _FakeAsyncStream: + def __init__(self, *, events=None, error=None): + self._events = list(events or []) + self._error = error + self.close_calls = 0 + self.final_message = SimpleNamespace(id="msg_final") + self.response = _FakeAsyncResponse() + + async def __anext__(self): + if self._events: + return self._events.pop(0) + if self._error is not None: + raise self._error + raise StopAsyncIteration + + async def close(self): + self.close_calls += 1 + + async def get_final_message(self): + return self.final_message + + +class _FakeSyncManager: + def __init__(self, stream, suppressed=False, exit_error=None): + self._stream = stream + self._suppressed = suppressed + self._exit_error = exit_error + self.exit_args = None + + def __enter__(self): + return self._stream + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + if self._exit_error is not None: + raise self._exit_error + return self._suppressed + + +class _FakeAsyncManager: + def __init__(self, stream, suppressed=False, exit_error=None): + self._stream = stream + self._suppressed = suppressed + self._exit_error = exit_error + self.exit_args = None + + async def __aenter__(self): + return self._stream + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + if self._exit_error is not None: + raise self._exit_error + return self._suppressed + + +class _FakeStreamWrapper: + def __init__(self): + self.exit_args = None + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + return False + + +class _FakeAsyncStreamWrapper: + def __init__(self): + self.exit_args = None + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + return False + + +class _FakeSyncResponse: + def __init__(self): + self.request_id = "req_sync" + self.close_calls = 0 + + def close(self): + self.close_calls += 1 + + +class _FakeAsyncResponse: + def __init__(self): + self.request_id = "req_async" + self.close_calls = 0 + self.aclose_calls = 0 + + def close(self): + self.close_calls += 1 + + async def aclose(self): + self.aclose_calls += 1 + + +def test_sync_stream_wrapper_exit_closes_without_exception(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + result = wrapper.__exit__(None, None, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + + +def test_sync_stream_wrapper_exit_fails_and_closes_on_exception(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + stopped = [] + failures = [] + + wrapper._stop = lambda: stopped.append(True) + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + error = ValueError("boom") + result = wrapper.__exit__(ValueError, error, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + assert failures == [("boom", ValueError)] + + +def test_sync_stream_wrapper_processes_events_and_stops_on_completion(): + event = SimpleNamespace(type="message_start") + stream = _FakeSyncStream(events=[event]) + wrapper = _make_stream_wrapper(stream) + processed = [] + stopped = [] + + wrapper._process_chunk = processed.append + wrapper._stop = lambda: stopped.append(True) + + result = next(wrapper) + + assert result is event + assert processed == [event] + + with pytest.raises(StopIteration): + next(wrapper) + + assert stopped == [True] + + +def test_sync_stream_wrapper_fails_and_reraises_stream_errors(): + error = ValueError("boom") + stream = _FakeSyncStream(error=error) + wrapper = _make_stream_wrapper(stream) + failures = [] + + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + with pytest.raises(ValueError, match="boom"): + next(wrapper) + + assert failures == [("boom", ValueError)] + + +def test_sync_stream_wrapper_getattr_passthrough(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + + assert wrapper.response.request_id == "req_sync" + + +def test_sync_stream_response_close_finalizes_wrapper(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + wrapper.response.close() + + assert stream.response.close_calls == 1 + assert stopped == [True] + + +def test_sync_manager_enter_constructs_stream_wrapper(): + stream = _FakeSyncStream() + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager(stream=stream), + handler=_make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + with wrapper as result: + assert isinstance(result, MessagesStreamWrapper) + assert result.stream is stream + assert wrapper._stream_wrapper is result + + +def test_sync_manager_exit_forwards_exception_to_stream_wrapper(): + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager(stream=SimpleNamespace(), suppressed=False), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("boom") + result = wrapper.__exit__(ValueError, error, None) + + assert result is False + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) + + +def test_sync_manager_exit_uses_none_exception_when_manager_suppresses(): + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager(stream=SimpleNamespace(), suppressed=True), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = RuntimeError("ignored") + result = wrapper.__exit__(RuntimeError, error, None) + + assert result is True + assert wrapper._manager.exit_args == (RuntimeError, error, None) + assert stream_wrapper.exit_args == (None, None, None) + + +def test_sync_manager_exit_still_finalizes_stream_wrapper_when_manager_raises(): + manager_error = RuntimeError("manager failure") + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager( + stream=SimpleNamespace(), + suppressed=False, + exit_error=manager_error, + ), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("outer") + with pytest.raises(RuntimeError, match="manager failure"): + wrapper.__exit__(ValueError, error, None) + + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_exit_closes_without_exception(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + result = await wrapper.__aexit__(None, None, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_exit_fails_and_closes_on_exception(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + failures = [] + + wrapper._stop = lambda: stopped.append(True) + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + error = ValueError("boom") + result = await wrapper.__aexit__(ValueError, error, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + assert failures == [("boom", ValueError)] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_close_uses_close_and_stops(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + await wrapper.close() + + assert stream.close_calls == 1 + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_processes_events_and_stops_on_completion(): + event = SimpleNamespace(type="message_start") + stream = _FakeAsyncStream(events=[event]) + wrapper = _make_async_stream_wrapper(stream) + processed = [] + stopped = [] + + wrapper._process_chunk = processed.append + wrapper._stop = lambda: stopped.append(True) + + result = await wrapper.__anext__() + + assert result is event + assert processed == [event] + + with pytest.raises(StopAsyncIteration): + await wrapper.__anext__() + + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_fails_and_reraises_stream_errors(): + error = ValueError("boom") + stream = _FakeAsyncStream(error=error) + wrapper = _make_async_stream_wrapper(stream) + failures = [] + + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + with pytest.raises(ValueError, match="boom"): + await wrapper.__anext__() + + assert failures == [("boom", ValueError)] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_preserves_stream_helper_methods(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + + result = await wrapper.get_final_message() + + assert result is stream.final_message + assert wrapper.response.request_id == "req_async" + + +@pytest.mark.asyncio +async def test_async_stream_response_aclose_finalizes_wrapper(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + await wrapper.response.aclose() + + assert stream.response.aclose_calls == 1 + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_manager_enter_constructs_async_stream_wrapper(): + stream = _FakeAsyncStream() + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager(stream=stream), + handler=_make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + async with wrapper as result: + assert isinstance(result, AsyncMessagesStreamWrapper) + assert result.stream is stream + assert wrapper._stream_wrapper is result + + +@pytest.mark.asyncio +async def test_async_manager_exit_forwards_exception_to_stream_wrapper(): + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager(stream=SimpleNamespace(), suppressed=False), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeAsyncStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("boom") + result = await wrapper.__aexit__(ValueError, error, None) + + assert result is False + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) + + +@pytest.mark.asyncio +async def test_async_manager_exit_uses_none_exception_when_manager_suppresses(): + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager(stream=SimpleNamespace(), suppressed=True), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeAsyncStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = RuntimeError("ignored") + result = await wrapper.__aexit__(RuntimeError, error, None) + + assert result is True + assert wrapper._manager.exit_args == (RuntimeError, error, None) + assert stream_wrapper.exit_args == (None, None, None) + + +@pytest.mark.asyncio +async def test_async_manager_exit_still_finalizes_stream_wrapper_when_manager_raises(): + manager_error = RuntimeError("manager failure") + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager( + stream=SimpleNamespace(), + suppressed=False, + exit_error=manager_error, + ), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeAsyncStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("outer") + with pytest.raises(RuntimeError, match="manager failure"): + await wrapper.__aexit__(ValueError, error, None) + + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) From ae6214d4114284c7bc20c16c989574e4d3bde1a1 Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Wed, 18 Mar 2026 00:41:35 -0400 Subject: [PATCH 2/8] wip: replacing the invocation response attributes with new attributes function. --- .../instrumentation/anthropic/wrappers.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py index 9dbc2bd1cb..b46a58d538 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py @@ -58,6 +58,24 @@ ResponseT = TypeVar("ResponseT") +def _set_response_attributes( + invocation: LLMInvocation, + wrapper: "MessagesStreamWrapper", +) -> None: + set_invocation_stream_response_attributes( + invocation, + response_model=wrapper._response_model, + response_id=wrapper._response_id, + stop_reason=wrapper._stop_reason, + input_tokens=wrapper._input_tokens, + output_tokens=wrapper._output_tokens, + cache_creation_input_tokens=wrapper._cache_creation_input_tokens, + cache_read_input_tokens=wrapper._cache_read_input_tokens, + capture_content=wrapper._capture_content, + content_blocks=wrapper._content_blocks, + ) + + class _ResponseProxy(Generic[ResponseT]): def __init__(self, response: ResponseT, finalize: Callable[[], None]): self._response = response @@ -188,26 +206,11 @@ def _safe_instrumentation( exc_info=True, ) - def _set_invocation_response_attributes(self) -> None: - """Extract accumulated stream state into the invocation.""" - set_invocation_stream_response_attributes( - self.invocation, - response_model=self._response_model, - response_id=self._response_id, - stop_reason=self._stop_reason, - input_tokens=self._input_tokens, - output_tokens=self._output_tokens, - cache_creation_input_tokens=self._cache_creation_input_tokens, - cache_read_input_tokens=self._cache_read_input_tokens, - capture_content=self._capture_content, - content_blocks=self._content_blocks, - ) - def _stop(self) -> None: if self._finalized: return with self._safe_instrumentation("response attribute extraction"): - self._set_invocation_response_attributes() + _set_response_attributes(self.invocation, self) with self._safe_instrumentation("stop_llm"): self.handler.stop_llm(self.invocation) self._finalized = True From 25766e1e5c2c91f79b982edea85ab67f04d27112 Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Wed, 18 Mar 2026 00:49:11 -0400 Subject: [PATCH 3/8] wip: cleaning up the function. --- .../tests/test_sync_messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py index 6bce6f2e16..a073f3b569 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py @@ -826,7 +826,7 @@ def __getattr__(self, name): return getattr(self._inner, name) monkeypatch.setattr( - stream, "_stream", ErrorInjectingStreamDelegate(stream._stream) + stream, "stream", ErrorInjectingStreamDelegate(stream.stream) ) with pytest.raises( From 73835165a1dfc6681ab09052ae78c47fd1381380 Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Wed, 18 Mar 2026 01:09:59 -0400 Subject: [PATCH 4/8] polish: cleaning up the response attribution using event accumulation method. --- .../anthropic/messages_extractors.py | 52 +--------- .../instrumentation/anthropic/wrappers.py | 94 +++++-------------- 2 files changed, 24 insertions(+), 122 deletions(-) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py index a69f9051a5..aa98a48b9a 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py @@ -38,7 +38,6 @@ from .utils import ( convert_content_to_parts, normalize_finish_reason, - stream_block_state_to_part, ) if TYPE_CHECKING: @@ -57,8 +56,6 @@ Usage, ) - from .utils import StreamBlockState - @dataclass class MessageRequestParams: @@ -157,7 +154,7 @@ def get_output_messages_from_message( ] -def set_invocation_message_response_attributes( +def set_invocation_response_attributes( invocation: LLMInvocation, message: Message | None, capture_content: bool, @@ -192,53 +189,6 @@ def set_invocation_message_response_attributes( invocation.output_messages = get_output_messages_from_message(message) -def set_invocation_stream_response_attributes( - invocation: LLMInvocation, - *, - response_model: str | None, - response_id: str | None, - stop_reason: str | None, - input_tokens: int | None, - output_tokens: int | None, - cache_creation_input_tokens: int | None, - cache_read_input_tokens: int | None, - capture_content: bool, - content_blocks: dict[int, "StreamBlockState"], -) -> None: - if response_model: - invocation.response_model_name = response_model - if response_id: - invocation.response_id = response_id - if stop_reason: - invocation.finish_reasons = [stop_reason] - if input_tokens is not None: - invocation.input_tokens = input_tokens - if output_tokens is not None: - invocation.output_tokens = output_tokens - if cache_creation_input_tokens is not None: - invocation.attributes[ - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS - ] = cache_creation_input_tokens - if cache_read_input_tokens is not None: - invocation.attributes[ - GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS - ] = cache_read_input_tokens - - if capture_content and content_blocks: - parts: list[MessagePart] = [] - for index in sorted(content_blocks): - part = stream_block_state_to_part(content_blocks[index]) - if part is not None: - parts.append(part) - invocation.output_messages = [ - OutputMessage( - role="assistant", - parts=parts, - finish_reason=stop_reason or "", - ) - ] - - def extract_params( # pylint: disable=too-many-locals *, max_tokens: int | None = None, diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py index b46a58d538..4f2977d885 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py @@ -25,17 +25,14 @@ LLMInvocation, ) -from .messages_extractors import ( - extract_usage_tokens, - set_invocation_message_response_attributes, - set_invocation_stream_response_attributes, -) -from .utils import ( - StreamBlockState, - create_stream_block_state, - normalize_finish_reason, - update_stream_block_state, -) +from .messages_extractors import set_invocation_response_attributes + +try: + from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module + accumulate_event, + ) +except ImportError: + accumulate_event = None if TYPE_CHECKING: from anthropic._streaming import AsyncStream, Stream @@ -48,9 +45,7 @@ ) from anthropic.types import ( Message, - MessageDeltaUsage, RawMessageStreamEvent, - Usage, ) @@ -60,20 +55,10 @@ def _set_response_attributes( invocation: LLMInvocation, - wrapper: "MessagesStreamWrapper", + result: "Message | None", + capture_content: bool, ) -> None: - set_invocation_stream_response_attributes( - invocation, - response_model=wrapper._response_model, - response_id=wrapper._response_id, - stop_reason=wrapper._stop_reason, - input_tokens=wrapper._input_tokens, - output_tokens=wrapper._output_tokens, - cache_creation_input_tokens=wrapper._cache_creation_input_tokens, - cache_read_input_tokens=wrapper._cache_read_input_tokens, - capture_content=wrapper._capture_content, - content_blocks=wrapper._content_blocks, - ) + set_invocation_response_attributes(invocation, result, capture_content) class _ResponseProxy(Generic[ResponseT]): @@ -121,7 +106,7 @@ def __init__(self, message: Message, capture_content: bool): def extract_into(self, invocation: LLMInvocation) -> None: """Extract response data into the invocation.""" - set_invocation_message_response_attributes( + set_invocation_response_attributes( invocation, self._message, self._capture_content ) @@ -144,53 +129,18 @@ def __init__( self.stream = stream self.handler = handler self.invocation = invocation - self._response_id: Optional[str] = None - self._response_model: Optional[str] = None - self._stop_reason: Optional[str] = None - self._input_tokens: Optional[int] = None - self._output_tokens: Optional[int] = None - self._cache_creation_input_tokens: Optional[int] = None - self._cache_read_input_tokens: Optional[int] = None + self._message: "Message | None" = None self._capture_content = capture_content - self._content_blocks: dict[int, StreamBlockState] = {} self._finalized = False - def _update_usage(self, usage: Usage | MessageDeltaUsage | None) -> None: - tokens = extract_usage_tokens(usage) - if tokens.input_tokens is not None: - self._input_tokens = tokens.input_tokens - if tokens.output_tokens is not None: - self._output_tokens = tokens.output_tokens - if tokens.cache_creation_input_tokens is not None: - self._cache_creation_input_tokens = ( - tokens.cache_creation_input_tokens - ) - if tokens.cache_read_input_tokens is not None: - self._cache_read_input_tokens = tokens.cache_read_input_tokens - def _process_chunk(self, chunk: RawMessageStreamEvent) -> None: - """Extract telemetry data from a streaming chunk.""" - if chunk.type == "message_start": - message = chunk.message - if message.id: - self._response_id = message.id - if message.model: - self._response_model = message.model - self._update_usage(message.usage) - elif chunk.type == "message_delta": - if chunk.delta.stop_reason: - self._stop_reason = normalize_finish_reason( - chunk.delta.stop_reason - ) - self._update_usage(chunk.usage) - elif self._capture_content and chunk.type == "content_block_start": - self._content_blocks[chunk.index] = create_stream_block_state( - chunk.content_block - ) - elif self._capture_content and chunk.type == "content_block_delta": - block = self._content_blocks.get(chunk.index) - if block is not None: - update_stream_block_state(block, chunk.delta) + """Accumulate a final message snapshot from a streaming chunk.""" + if accumulate_event is None: + return + self._message = accumulate_event( + event=chunk, + current_snapshot=self._message, + ) @staticmethod @contextmanager @@ -210,7 +160,9 @@ def _stop(self) -> None: if self._finalized: return with self._safe_instrumentation("response attribute extraction"): - _set_response_attributes(self.invocation, self) + _set_response_attributes( + self.invocation, self._message, self._capture_content + ) with self._safe_instrumentation("stop_llm"): self.handler.stop_llm(self.invocation) self._finalized = True From 693b9420f1eef1962319a51f95e438d3105c2025 Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Wed, 18 Mar 2026 18:42:04 -0400 Subject: [PATCH 5/8] wip: rearranging the code in anthropic wrappers. --- .../instrumentation/anthropic/wrappers.py | 134 +++++++++--------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py index 4f2977d885..348d37995d 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py @@ -133,28 +133,55 @@ def __init__( self._capture_content = capture_content self._finalized = False - def _process_chunk(self, chunk: RawMessageStreamEvent) -> None: - """Accumulate a final message snapshot from a streaming chunk.""" - if accumulate_event is None: - return - self._message = accumulate_event( - event=chunk, - current_snapshot=self._message, - ) + def __enter__(self) -> MessagesStreamWrapper: + return self - @staticmethod - @contextmanager - def _safe_instrumentation( - context: str, - ) -> Generator[None, None, None]: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: try: - yield - except Exception: # pylint: disable=broad-exception-caught - _logger.debug( - "Anthropic MessagesStreamWrapper instrumentation error in %s", - context, - exc_info=True, - ) + if exc_type is not None: + self._fail( + str(exc_val), type(exc_val) if exc_val else Exception + ) + finally: + self.close() + return False + + def close(self) -> None: + try: + self.stream.close() + finally: + self._stop() + + def __iter__(self) -> MessagesStreamWrapper: + return self + + def __next__(self) -> "RawMessageStreamEvent | MessageStreamEvent": + try: + chunk = next(self.stream) + except StopIteration: + self._stop() + raise + except Exception as exc: + self._fail(str(exc), type(exc)) + raise + with self._safe_instrumentation("stream chunk processing"): + self._process_chunk(chunk) + return chunk + + def __getattr__(self, name: str) -> object: + return getattr(self.stream, name) + + @property + def response(self): + response = getattr(self.stream, "response", None) + if response is None: + return None + return _ResponseProxy(response, self._stop) def _stop(self) -> None: if self._finalized: @@ -176,55 +203,28 @@ def _fail(self, message: str, error_type: type[BaseException]) -> None: ) self._finalized = True - def __iter__(self) -> MessagesStreamWrapper: - return self - - def __getattr__(self, name: str) -> object: - return getattr(self.stream, name) - - @property - def response(self): - response = getattr(self.stream, "response", None) - if response is None: - return None - return _ResponseProxy(response, self._stop) - - def __next__(self) -> "RawMessageStreamEvent | MessageStreamEvent": - try: - chunk = next(self.stream) - except StopIteration: - self._stop() - raise - except Exception as exc: - self._fail(str(exc), type(exc)) - raise - with self._safe_instrumentation("stream chunk processing"): - self._process_chunk(chunk) - return chunk - - def __enter__(self) -> MessagesStreamWrapper: - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool: + @staticmethod + @contextmanager + def _safe_instrumentation( + context: str, + ) -> Generator[None, None, None]: try: - if exc_type is not None: - self._fail( - str(exc_val), type(exc_val) if exc_val else Exception - ) - finally: - self.close() - return False + yield + except Exception: # pylint: disable=broad-exception-caught + _logger.debug( + "Anthropic MessagesStreamWrapper instrumentation error in %s", + context, + exc_info=True, + ) - def close(self) -> None: - try: - self.stream.close() - finally: - self._stop() + def _process_chunk(self, chunk: RawMessageStreamEvent) -> None: + """Accumulate a final message snapshot from a streaming chunk.""" + if accumulate_event is None: + return + self._message = accumulate_event( + event=chunk, + current_snapshot=self._message, + ) class AsyncMessagesStreamWrapper(MessagesStreamWrapper): From 4e3243d243587bfbb8489e30b48295781cc71aa5 Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Wed, 18 Mar 2026 18:58:11 -0400 Subject: [PATCH 6/8] polish: added a changelog. --- .../opentelemetry-instrumentation-anthropic/CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md b/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md index 5cf210fbd1..dfa9e01631 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add async Anthropic message stream wrappers and manager wrappers, with wrapper + tests ([#4346](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4346)) + - `AsyncMessagesStreamWrapper` for async message stream telemetry + - `AsyncMessagesStreamManagerWrapper` for async `Messages.stream()` telemetry - Add sync streaming support for `Messages.create(stream=True)` and `Messages.stream()` ([#4155](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4155)) - `StreamWrapper` for handling `Messages.create(stream=True)` telemetry @@ -22,4 +26,3 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Captures response attributes: `gen_ai.response.id`, `gen_ai.response.model`, `gen_ai.response.finish_reasons`, `gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens` - Error handling with `error.type` attribute - Minimum supported anthropic version is 0.16.0 (SDK uses modern `anthropic.resources.messages` module structure introduced in this version) - From 91c33429c86cb4167d33a61bd413e3b92205ec3a Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Wed, 18 Mar 2026 19:27:02 -0400 Subject: [PATCH 7/8] wip: fixing lint, typecheck and precommit failures. --- .../anthropic/messages_extractors.py | 6 +- .../instrumentation/anthropic/patch.py | 15 +- .../instrumentation/anthropic/wrappers.py | 167 +++++++++++++----- 3 files changed, 132 insertions(+), 56 deletions(-) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py index aa98a48b9a..6f8f786729 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py @@ -177,9 +177,9 @@ def set_invocation_response_attributes( invocation.input_tokens = tokens.input_tokens invocation.output_tokens = tokens.output_tokens if tokens.cache_creation_input_tokens is not None: - invocation.attributes[ - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS - ] = tokens.cache_creation_input_tokens + invocation.attributes[GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS] = ( + tokens.cache_creation_input_tokens + ) if tokens.cache_read_input_tokens is not None: invocation.attributes[GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS] = ( tokens.cache_read_input_tokens diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py index 2392dfcd08..d92a6d4e46 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py @@ -14,6 +14,8 @@ """Patching functions for Anthropic instrumentation.""" +from __future__ import annotations + import logging from typing import TYPE_CHECKING, Any, Callable, Union, cast @@ -56,7 +58,7 @@ def messages_create( Union[ "AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper, + MessagesStreamWrapper[RawMessageStreamEvent], ], ]: """Wrap the `create` method of the `Messages` class to trace it.""" @@ -76,7 +78,7 @@ def traced_method( ) -> Union[ "AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper, + MessagesStreamWrapper[RawMessageStreamEvent], ]: params = extract_params(*args, **kwargs) attributes = get_llm_request_attributes(params, instance) @@ -121,13 +123,6 @@ def traced_method( raise return cast( - Callable[ - ..., - Union[ - "AnthropicMessage", - "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper, - ], - ], + 'Callable[..., Union["AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", MessagesStreamWrapper[RawMessageStreamEvent]]]', traced_method, ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py index 348d37995d..e04faa5f5a 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py @@ -17,7 +17,17 @@ import logging from contextlib import AsyncExitStack, ExitStack, contextmanager from types import TracebackType -from typing import TYPE_CHECKING, Callable, Generator, Generic, Iterator, Optional, TypeVar +from typing import ( + TYPE_CHECKING, + AsyncIterator, + Callable, + Generator, + Generic, + Iterator, + Protocol, + TypeVar, + cast, +) from opentelemetry.util.genai.handler import TelemetryHandler from opentelemetry.util.genai.types import ( @@ -29,20 +39,19 @@ try: from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module - accumulate_event, + accumulate_event as _sdk_accumulate_event, ) except ImportError: - accumulate_event = None + _sdk_accumulate_event = None if TYPE_CHECKING: - from anthropic._streaming import AsyncStream, Stream from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module - AsyncMessageStream, AsyncMessageStreamManager, - MessageStream, - MessageStreamEvent, MessageStreamManager, ) + from anthropic.lib.streaming._types import ( # pylint: disable=no-name-in-module + MessageStreamEvent, + ) from anthropic.types import ( Message, RawMessageStreamEvent, @@ -50,7 +59,48 @@ _logger = logging.getLogger(__name__) -ResponseT = TypeVar("ResponseT") +SyncResponseT = TypeVar("SyncResponseT", bound="_SupportsClose") +AsyncResponseT = TypeVar("AsyncResponseT", bound="_SupportsAclose") +StreamEventT = TypeVar( + "StreamEventT", "RawMessageStreamEvent", "MessageStreamEvent" +) +StreamEventT_co = TypeVar( + "StreamEventT_co", + "RawMessageStreamEvent", + "MessageStreamEvent", + covariant=True, +) +accumulate_event = cast("Callable[..., Message] | None", _sdk_accumulate_event) + + +class _SupportsClose(Protocol): + def close(self) -> None: ... + + +class _SupportsAclose(_SupportsClose, Protocol): + async def aclose(self) -> None: ... + + +class _SyncStream(Protocol[StreamEventT_co]): + @property + def response(self) -> _SupportsClose: ... + + def __iter__(self) -> Iterator[StreamEventT_co]: ... + + def __next__(self) -> StreamEventT_co: ... + + def close(self) -> None: ... + + +class _AsyncStream(Protocol[StreamEventT_co]): + @property + def response(self) -> _SupportsAclose: ... + + def __aiter__(self) -> AsyncIterator[StreamEventT_co]: ... + + async def __anext__(self) -> StreamEventT_co: ... + + async def close(self) -> None: ... def _set_response_attributes( @@ -61,8 +111,8 @@ def _set_response_attributes( set_invocation_response_attributes(invocation, result, capture_content) -class _ResponseProxy(Generic[ResponseT]): - def __init__(self, response: ResponseT, finalize: Callable[[], None]): +class _ResponseProxy(Generic[SyncResponseT]): + def __init__(self, response: SyncResponseT, finalize: Callable[[], None]): self._response = response self._finalize = finalize @@ -76,8 +126,8 @@ def __getattr__(self, name: str): return getattr(self._response, name) -class _AsyncResponseProxy(Generic[ResponseT]): - def __init__(self, response: ResponseT, finalize: Callable[[], None]): +class _AsyncResponseProxy(Generic[AsyncResponseT]): + def __init__(self, response: AsyncResponseT, finalize: Callable[[], None]): self._response = response self._finalize = finalize @@ -116,12 +166,12 @@ def message(self) -> Message: return self._message -class MessagesStreamWrapper(Iterator["RawMessageStreamEvent"]): +class MessagesStreamWrapper(Generic[StreamEventT], Iterator[StreamEventT]): """Wrapper for Anthropic Stream that handles telemetry.""" def __init__( self, - stream: Stream[RawMessageStreamEvent], + stream: _SyncStream[StreamEventT], handler: TelemetryHandler, invocation: LLMInvocation, capture_content: bool, @@ -133,7 +183,7 @@ def __init__( self._capture_content = capture_content self._finalized = False - def __enter__(self) -> MessagesStreamWrapper: + def __enter__(self) -> "MessagesStreamWrapper[StreamEventT]": return self def __exit__( @@ -157,10 +207,10 @@ def close(self) -> None: finally: self._stop() - def __iter__(self) -> MessagesStreamWrapper: + def __iter__(self) -> "MessagesStreamWrapper[StreamEventT]": return self - def __next__(self) -> "RawMessageStreamEvent | MessageStreamEvent": + def __next__(self) -> StreamEventT: try: chunk = next(self.stream) except StopIteration: @@ -177,11 +227,10 @@ def __getattr__(self, name: str) -> object: return getattr(self.stream, name) @property - def response(self): - response = getattr(self.stream, "response", None) - if response is None: - return None - return _ResponseProxy(response, self._stop) + def response( + self, + ) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None": + return _ResponseProxy(self.stream.response, self._stop) def _stop(self) -> None: if self._finalized: @@ -217,22 +266,43 @@ def _safe_instrumentation( exc_info=True, ) - def _process_chunk(self, chunk: RawMessageStreamEvent) -> None: + def _process_chunk(self, chunk: StreamEventT) -> None: """Accumulate a final message snapshot from a streaming chunk.""" + snapshot = cast( + "Message | None", + getattr(self.stream, "current_message_snapshot", None), + ) + if snapshot is not None: + self._message = snapshot + return if accumulate_event is None: return self._message = accumulate_event( - event=chunk, + event=cast("RawMessageStreamEvent", chunk), current_snapshot=self._message, ) -class AsyncMessagesStreamWrapper(MessagesStreamWrapper): +class AsyncMessagesStreamWrapper(MessagesStreamWrapper[StreamEventT]): """Wrapper for async Anthropic Stream that handles telemetry.""" - stream: "AsyncStream[RawMessageStreamEvent] | AsyncMessageStream" + stream: _AsyncStream[StreamEventT] + + def __init__( + self, + stream: _AsyncStream[StreamEventT], + handler: TelemetryHandler, + invocation: LLMInvocation, + capture_content: bool, + ): + self.stream = stream + self.handler = handler + self.invocation = invocation + self._message: "Message | None" = None + self._capture_content = capture_content + self._finalized = False - async def __aenter__(self) -> "AsyncMessagesStreamWrapper": + async def __aenter__(self) -> "AsyncMessagesStreamWrapper[StreamEventT]": return self async def __aexit__( @@ -256,17 +326,16 @@ async def close(self) -> None: # type: ignore[override] finally: self._stop() - def __aiter__(self) -> "AsyncMessagesStreamWrapper": + def __aiter__(self) -> "AsyncMessagesStreamWrapper[StreamEventT]": return self @property - def response(self): - response = getattr(self.stream, "response", None) - if response is None: - return None - return _AsyncResponseProxy(response, self._stop) + def response( + self, + ) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None": + return _AsyncResponseProxy(self.stream.response, self._stop) - async def __anext__(self) -> "RawMessageStreamEvent | MessageStreamEvent": + async def __anext__(self) -> StreamEventT: try: chunk = await self.stream.__anext__() except StopAsyncIteration: @@ -294,10 +363,15 @@ def __init__( self._handler = handler self._invocation = invocation self._capture_content = capture_content - self._stream_wrapper: MessagesStreamWrapper | None = None - - def __enter__(self) -> MessagesStreamWrapper: - stream = self._manager.__enter__() + self._stream_wrapper: ( + MessagesStreamWrapper[MessageStreamEvent] | None + ) = None + + def __enter__(self) -> MessagesStreamWrapper[MessageStreamEvent]: + stream = cast( + "_SyncStream[MessageStreamEvent]", + self._manager.__enter__(), + ) self._stream_wrapper = MessagesStreamWrapper( stream, self._handler, @@ -311,7 +385,7 @@ def __exit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool: + ) -> bool | None: suppressed = False stream_wrapper = self._stream_wrapper self._stream_wrapper = None @@ -350,10 +424,17 @@ def __init__( self._handler = handler self._invocation = invocation self._capture_content = capture_content - self._stream_wrapper: AsyncMessagesStreamWrapper | None = None + self._stream_wrapper: ( + AsyncMessagesStreamWrapper[MessageStreamEvent] | None + ) = None - async def __aenter__(self) -> AsyncMessagesStreamWrapper: - msg_stream = await self._manager.__aenter__() + async def __aenter__( + self, + ) -> AsyncMessagesStreamWrapper[MessageStreamEvent]: + msg_stream = cast( + "_AsyncStream[MessageStreamEvent]", + await self._manager.__aenter__(), + ) self._stream_wrapper = AsyncMessagesStreamWrapper( msg_stream, self._handler, @@ -367,7 +448,7 @@ async def __aexit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool: + ) -> bool | None: suppressed = False stream_wrapper = self._stream_wrapper self._stream_wrapper = None From 963366a42dbcdd618c3430daacea35db263b8cf6 Mon Sep 17 00:00:00 2001 From: eternalcuriouslearner Date: Thu, 19 Mar 2026 00:53:46 -0400 Subject: [PATCH 8/8] wip: keeping anthropic wrapper same as openai wrapper. --- .../instrumentation/anthropic/patch.py | 6 +- .../instrumentation/anthropic/wrappers.py | 156 +++++++----------- 2 files changed, 62 insertions(+), 100 deletions(-) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py index d92a6d4e46..4dc8b55d38 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py @@ -58,7 +58,7 @@ def messages_create( Union[ "AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper[RawMessageStreamEvent], + MessagesStreamWrapper[None], ], ]: """Wrap the `create` method of the `Messages` class to trace it.""" @@ -78,7 +78,7 @@ def traced_method( ) -> Union[ "AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper[RawMessageStreamEvent], + MessagesStreamWrapper[None], ]: params = extract_params(*args, **kwargs) attributes = get_llm_request_attributes(params, instance) @@ -123,6 +123,6 @@ def traced_method( raise return cast( - 'Callable[..., Union["AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", MessagesStreamWrapper[RawMessageStreamEvent]]]', + 'Callable[..., Union["AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", MessagesStreamWrapper[None]]]', traced_method, ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py index e04faa5f5a..52e2e68582 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py @@ -19,12 +19,11 @@ from types import TracebackType from typing import ( TYPE_CHECKING, - AsyncIterator, + Any, Callable, Generator, Generic, Iterator, - Protocol, TypeVar, cast, ) @@ -45,64 +44,29 @@ _sdk_accumulate_event = None if TYPE_CHECKING: + from anthropic._streaming import AsyncStream, Stream from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module + AsyncMessageStream, AsyncMessageStreamManager, + MessageStream, MessageStreamManager, ) from anthropic.lib.streaming._types import ( # pylint: disable=no-name-in-module - MessageStreamEvent, + ParsedMessageStreamEvent, ) from anthropic.types import ( Message, RawMessageStreamEvent, ) + from anthropic.types.parsed_message import ParsedMessage _logger = logging.getLogger(__name__) -SyncResponseT = TypeVar("SyncResponseT", bound="_SupportsClose") -AsyncResponseT = TypeVar("AsyncResponseT", bound="_SupportsAclose") -StreamEventT = TypeVar( - "StreamEventT", "RawMessageStreamEvent", "MessageStreamEvent" -) -StreamEventT_co = TypeVar( - "StreamEventT_co", - "RawMessageStreamEvent", - "MessageStreamEvent", - covariant=True, -) +ResponseT = TypeVar("ResponseT") +ResponseFormatT = TypeVar("ResponseFormatT") accumulate_event = cast("Callable[..., Message] | None", _sdk_accumulate_event) -class _SupportsClose(Protocol): - def close(self) -> None: ... - - -class _SupportsAclose(_SupportsClose, Protocol): - async def aclose(self) -> None: ... - - -class _SyncStream(Protocol[StreamEventT_co]): - @property - def response(self) -> _SupportsClose: ... - - def __iter__(self) -> Iterator[StreamEventT_co]: ... - - def __next__(self) -> StreamEventT_co: ... - - def close(self) -> None: ... - - -class _AsyncStream(Protocol[StreamEventT_co]): - @property - def response(self) -> _SupportsAclose: ... - - def __aiter__(self) -> AsyncIterator[StreamEventT_co]: ... - - async def __anext__(self) -> StreamEventT_co: ... - - async def close(self) -> None: ... - - def _set_response_attributes( invocation: LLMInvocation, result: "Message | None", @@ -111,9 +75,9 @@ def _set_response_attributes( set_invocation_response_attributes(invocation, result, capture_content) -class _ResponseProxy(Generic[SyncResponseT]): - def __init__(self, response: SyncResponseT, finalize: Callable[[], None]): - self._response = response +class _ResponseProxy(Generic[ResponseT]): + def __init__(self, response: ResponseT, finalize: Callable[[], None]): + self._response: Any = response self._finalize = finalize def close(self) -> None: @@ -126,17 +90,11 @@ def __getattr__(self, name: str): return getattr(self._response, name) -class _AsyncResponseProxy(Generic[AsyncResponseT]): - def __init__(self, response: AsyncResponseT, finalize: Callable[[], None]): - self._response = response +class _AsyncResponseProxy(Generic[ResponseT]): + def __init__(self, response: ResponseT, finalize: Callable[[], None]): + self._response: Any = response self._finalize = finalize - def close(self) -> None: - try: - self._response.close() - finally: - self._finalize() - async def aclose(self) -> None: try: await self._response.aclose() @@ -166,12 +124,17 @@ def message(self) -> Message: return self._message -class MessagesStreamWrapper(Generic[StreamEventT], Iterator[StreamEventT]): +class MessagesStreamWrapper( + Generic[ResponseFormatT], + Iterator[ + "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" + ], +): """Wrapper for Anthropic Stream that handles telemetry.""" def __init__( self, - stream: _SyncStream[StreamEventT], + stream: "Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT]", handler: TelemetryHandler, invocation: LLMInvocation, capture_content: bool, @@ -179,11 +142,11 @@ def __init__( self.stream = stream self.handler = handler self.invocation = invocation - self._message: "Message | None" = None + self._message: "Message | ParsedMessage[ResponseFormatT] | None" = None self._capture_content = capture_content self._finalized = False - def __enter__(self) -> "MessagesStreamWrapper[StreamEventT]": + def __enter__(self) -> "MessagesStreamWrapper[ResponseFormatT]": return self def __exit__( @@ -207,10 +170,12 @@ def close(self) -> None: finally: self._stop() - def __iter__(self) -> "MessagesStreamWrapper[StreamEventT]": + def __iter__(self) -> "MessagesStreamWrapper[ResponseFormatT]": return self - def __next__(self) -> StreamEventT: + def __next__( + self, + ) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]": try: chunk = next(self.stream) except StopIteration: @@ -227,9 +192,7 @@ def __getattr__(self, name: str) -> object: return getattr(self.stream, name) @property - def response( - self, - ) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None": + def response(self): return _ResponseProxy(self.stream.response, self._stop) def _stop(self) -> None: @@ -266,10 +229,13 @@ def _safe_instrumentation( exc_info=True, ) - def _process_chunk(self, chunk: StreamEventT) -> None: + def _process_chunk( + self, + chunk: "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]", + ) -> None: """Accumulate a final message snapshot from a streaming chunk.""" snapshot = cast( - "Message | None", + "ParsedMessage[ResponseFormatT] | None", getattr(self.stream, "current_message_snapshot", None), ) if snapshot is not None: @@ -279,18 +245,18 @@ def _process_chunk(self, chunk: StreamEventT) -> None: return self._message = accumulate_event( event=cast("RawMessageStreamEvent", chunk), - current_snapshot=self._message, + current_snapshot=cast( + "ParsedMessage[ResponseFormatT] | None", self._message + ), ) -class AsyncMessagesStreamWrapper(MessagesStreamWrapper[StreamEventT]): +class AsyncMessagesStreamWrapper(MessagesStreamWrapper[ResponseFormatT]): """Wrapper for async Anthropic Stream that handles telemetry.""" - stream: _AsyncStream[StreamEventT] - def __init__( self, - stream: _AsyncStream[StreamEventT], + stream: "AsyncStream[RawMessageStreamEvent] | AsyncMessageStream[ResponseFormatT]", handler: TelemetryHandler, invocation: LLMInvocation, capture_content: bool, @@ -298,11 +264,13 @@ def __init__( self.stream = stream self.handler = handler self.invocation = invocation - self._message: "Message | None" = None + self._message: "Message | ParsedMessage[ResponseFormatT] | None" = None self._capture_content = capture_content self._finalized = False - async def __aenter__(self) -> "AsyncMessagesStreamWrapper[StreamEventT]": + async def __aenter__( + self, + ) -> "AsyncMessagesStreamWrapper[ResponseFormatT]": return self async def __aexit__( @@ -326,16 +294,16 @@ async def close(self) -> None: # type: ignore[override] finally: self._stop() - def __aiter__(self) -> "AsyncMessagesStreamWrapper[StreamEventT]": + def __aiter__(self) -> "AsyncMessagesStreamWrapper[ResponseFormatT]": return self @property - def response( - self, - ) -> "_ResponseProxy[_SupportsClose] | _AsyncResponseProxy[_SupportsAclose] | None": + def response(self) -> Any: return _AsyncResponseProxy(self.stream.response, self._stop) - async def __anext__(self) -> StreamEventT: + async def __anext__( + self, + ) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]": try: chunk = await self.stream.__anext__() except StopAsyncIteration: @@ -349,12 +317,12 @@ async def __anext__(self) -> StreamEventT: return chunk -class MessagesStreamManagerWrapper: +class MessagesStreamManagerWrapper(Generic[ResponseFormatT]): """Wrapper for sync Anthropic stream managers.""" def __init__( self, - manager: "MessageStreamManager", + manager: "MessageStreamManager[ResponseFormatT]", handler: TelemetryHandler, invocation: LLMInvocation, capture_content: bool, @@ -363,15 +331,12 @@ def __init__( self._handler = handler self._invocation = invocation self._capture_content = capture_content - self._stream_wrapper: ( - MessagesStreamWrapper[MessageStreamEvent] | None - ) = None - - def __enter__(self) -> MessagesStreamWrapper[MessageStreamEvent]: - stream = cast( - "_SyncStream[MessageStreamEvent]", - self._manager.__enter__(), + self._stream_wrapper: MessagesStreamWrapper[ResponseFormatT] | None = ( + None ) + + def __enter__(self) -> MessagesStreamWrapper[ResponseFormatT]: + stream = self._manager.__enter__() self._stream_wrapper = MessagesStreamWrapper( stream, self._handler, @@ -406,7 +371,7 @@ def __getattr__(self, name: str) -> object: return getattr(self._manager, name) -class AsyncMessagesStreamManagerWrapper: +class AsyncMessagesStreamManagerWrapper(Generic[ResponseFormatT]): """Wrapper for AsyncMessageStreamManager that handles telemetry. Wraps AsyncMessageStreamManager from the Anthropic SDK: @@ -415,7 +380,7 @@ class AsyncMessagesStreamManagerWrapper: def __init__( self, - manager: "AsyncMessageStreamManager", + manager: "AsyncMessageStreamManager[ResponseFormatT]", handler: TelemetryHandler, invocation: LLMInvocation, capture_content: bool, @@ -425,16 +390,13 @@ def __init__( self._invocation = invocation self._capture_content = capture_content self._stream_wrapper: ( - AsyncMessagesStreamWrapper[MessageStreamEvent] | None + AsyncMessagesStreamWrapper[ResponseFormatT] | None ) = None async def __aenter__( self, - ) -> AsyncMessagesStreamWrapper[MessageStreamEvent]: - msg_stream = cast( - "_AsyncStream[MessageStreamEvent]", - await self._manager.__aenter__(), - ) + ) -> AsyncMessagesStreamWrapper[ResponseFormatT]: + msg_stream = await self._manager.__aenter__() self._stream_wrapper = AsyncMessagesStreamWrapper( msg_stream, self._handler,