diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md b/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md index 5cf210fbd1..dfa9e01631 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add async Anthropic message stream wrappers and manager wrappers, with wrapper + tests ([#4346](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4346)) + - `AsyncMessagesStreamWrapper` for async message stream telemetry + - `AsyncMessagesStreamManagerWrapper` for async `Messages.stream()` telemetry - Add sync streaming support for `Messages.create(stream=True)` and `Messages.stream()` ([#4155](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4155)) - `StreamWrapper` for handling `Messages.create(stream=True)` telemetry @@ -22,4 +26,3 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Captures response attributes: `gen_ai.response.id`, `gen_ai.response.model`, `gen_ai.response.finish_reasons`, `gen_ai.usage.input_tokens`, `gen_ai.usage.output_tokens` - Error handling with `error.type` attribute - Minimum supported anthropic version is 0.16.0 (SDK uses modern `anthropic.resources.messages` module structure introduced in this version) - diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py index 907aa08d39..6f8f786729 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/messages_extractors.py @@ -29,6 +29,7 @@ ) from opentelemetry.util.genai.types import ( InputMessage, + LLMInvocation, MessagePart, OutputMessage, ) @@ -153,6 +154,41 @@ def get_output_messages_from_message( ] +def set_invocation_response_attributes( + invocation: LLMInvocation, + message: Message | None, + capture_content: bool, +) -> None: + if message is None: + return + + if message.model: + invocation.response_model_name = message.model + + if message.id: + invocation.response_id = message.id + + finish_reason = normalize_finish_reason(message.stop_reason) + if finish_reason: + invocation.finish_reasons = [finish_reason] + + if message.usage: + tokens = extract_usage_tokens(message.usage) + invocation.input_tokens = tokens.input_tokens + invocation.output_tokens = tokens.output_tokens + if tokens.cache_creation_input_tokens is not None: + invocation.attributes[GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS] = ( + tokens.cache_creation_input_tokens + ) + if tokens.cache_read_input_tokens is not None: + invocation.attributes[GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS] = ( + tokens.cache_read_input_tokens + ) + + if capture_content: + invocation.output_messages = get_output_messages_from_message(message) + + def extract_params( # pylint: disable=too-many-locals *, max_tokens: int | None = None, diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py index 2392dfcd08..4dc8b55d38 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/patch.py @@ -14,6 +14,8 @@ """Patching functions for Anthropic instrumentation.""" +from __future__ import annotations + import logging from typing import TYPE_CHECKING, Any, Callable, Union, cast @@ -56,7 +58,7 @@ def messages_create( Union[ "AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper, + MessagesStreamWrapper[None], ], ]: """Wrap the `create` method of the `Messages` class to trace it.""" @@ -76,7 +78,7 @@ def traced_method( ) -> Union[ "AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper, + MessagesStreamWrapper[None], ]: params = extract_params(*args, **kwargs) attributes = get_llm_request_attributes(params, instance) @@ -121,13 +123,6 @@ def traced_method( raise return cast( - Callable[ - ..., - Union[ - "AnthropicMessage", - "AnthropicStream[RawMessageStreamEvent]", - MessagesStreamWrapper, - ], - ], + 'Callable[..., Union["AnthropicMessage", "AnthropicStream[RawMessageStreamEvent]", MessagesStreamWrapper[None]]]', traced_method, ) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py index 585d70a9f9..52e2e68582 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/src/opentelemetry/instrumentation/anthropic/wrappers.py @@ -15,42 +15,94 @@ from __future__ import annotations import logging +from contextlib import AsyncExitStack, ExitStack, contextmanager from types import TracebackType -from typing import TYPE_CHECKING, Callable, Iterator, Optional +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Generic, + Iterator, + TypeVar, + cast, +) from opentelemetry.util.genai.handler import TelemetryHandler from opentelemetry.util.genai.types import ( Error, LLMInvocation, - MessagePart, - OutputMessage, ) -from .messages_extractors import ( - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS, - GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS, - extract_usage_tokens, - get_output_messages_from_message, -) -from .utils import ( - StreamBlockState, - create_stream_block_state, - normalize_finish_reason, - stream_block_state_to_part, - update_stream_block_state, -) +from .messages_extractors import set_invocation_response_attributes + +try: + from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module + accumulate_event as _sdk_accumulate_event, + ) +except ImportError: + _sdk_accumulate_event = None if TYPE_CHECKING: - from anthropic._streaming import Stream + from anthropic._streaming import AsyncStream, Stream + from anthropic.lib.streaming._messages import ( # pylint: disable=no-name-in-module + AsyncMessageStream, + AsyncMessageStreamManager, + MessageStream, + MessageStreamManager, + ) + from anthropic.lib.streaming._types import ( # pylint: disable=no-name-in-module + ParsedMessageStreamEvent, + ) from anthropic.types import ( Message, - MessageDeltaUsage, RawMessageStreamEvent, - Usage, ) + from anthropic.types.parsed_message import ParsedMessage _logger = logging.getLogger(__name__) +ResponseT = TypeVar("ResponseT") +ResponseFormatT = TypeVar("ResponseFormatT") +accumulate_event = cast("Callable[..., Message] | None", _sdk_accumulate_event) + + +def _set_response_attributes( + invocation: LLMInvocation, + result: "Message | None", + capture_content: bool, +) -> None: + set_invocation_response_attributes(invocation, result, capture_content) + + +class _ResponseProxy(Generic[ResponseT]): + def __init__(self, response: ResponseT, finalize: Callable[[], None]): + self._response: Any = response + self._finalize = finalize + + def close(self) -> None: + try: + self._response.close() + finally: + self._finalize() + + def __getattr__(self, name: str): + return getattr(self._response, name) + + +class _AsyncResponseProxy(Generic[ResponseT]): + def __init__(self, response: ResponseT, finalize: Callable[[], None]): + self._response: Any = response + self._finalize = finalize + + async def aclose(self) -> None: + try: + await self._response.aclose() + finally: + self._finalize() + + def __getattr__(self, name: str): + return getattr(self._response, name) class MessageWrapper: @@ -62,33 +114,9 @@ def __init__(self, message: Message, capture_content: bool): def extract_into(self, invocation: LLMInvocation) -> None: """Extract response data into the invocation.""" - if self._message.model: - invocation.response_model_name = self._message.model - - if self._message.id: - invocation.response_id = self._message.id - - finish_reason = normalize_finish_reason(self._message.stop_reason) - if finish_reason: - invocation.finish_reasons = [finish_reason] - - if self._message.usage: - tokens = extract_usage_tokens(self._message.usage) - invocation.input_tokens = tokens.input_tokens - invocation.output_tokens = tokens.output_tokens - if tokens.cache_creation_input_tokens is not None: - invocation.attributes[ - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS - ] = tokens.cache_creation_input_tokens - if tokens.cache_read_input_tokens is not None: - invocation.attributes[GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS] = ( - tokens.cache_read_input_tokens - ) - - if self._capture_content: - invocation.output_messages = get_output_messages_from_message( - self._message - ) + set_invocation_response_attributes( + invocation, self._message, self._capture_content + ) @property def message(self) -> Message: @@ -96,73 +124,104 @@ def message(self) -> Message: return self._message -class MessagesStreamWrapper(Iterator["RawMessageStreamEvent"]): +class MessagesStreamWrapper( + Generic[ResponseFormatT], + Iterator[ + "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]" + ], +): """Wrapper for Anthropic Stream that handles telemetry.""" def __init__( self, - stream: Stream[RawMessageStreamEvent], + stream: "Stream[RawMessageStreamEvent] | MessageStream[ResponseFormatT]", handler: TelemetryHandler, invocation: LLMInvocation, capture_content: bool, ): - self._stream = stream - self._handler = handler - self._invocation = invocation - self._response_id: Optional[str] = None - self._response_model: Optional[str] = None - self._stop_reason: Optional[str] = None - self._input_tokens: Optional[int] = None - self._output_tokens: Optional[int] = None - self._cache_creation_input_tokens: Optional[int] = None - self._cache_read_input_tokens: Optional[int] = None + self.stream = stream + self.handler = handler + self.invocation = invocation + self._message: "Message | ParsedMessage[ResponseFormatT] | None" = None self._capture_content = capture_content - self._content_blocks: dict[int, StreamBlockState] = {} self._finalized = False - def _update_usage(self, usage: Usage | MessageDeltaUsage | None) -> None: - tokens = extract_usage_tokens(usage) - if tokens.input_tokens is not None: - self._input_tokens = tokens.input_tokens - if tokens.output_tokens is not None: - self._output_tokens = tokens.output_tokens - if tokens.cache_creation_input_tokens is not None: - self._cache_creation_input_tokens = ( - tokens.cache_creation_input_tokens - ) - if tokens.cache_read_input_tokens is not None: - self._cache_read_input_tokens = tokens.cache_read_input_tokens - - def _process_chunk(self, chunk: RawMessageStreamEvent) -> None: - """Extract telemetry data from a streaming chunk.""" - if chunk.type == "message_start": - message = chunk.message - if message.id: - self._response_id = message.id - if message.model: - self._response_model = message.model - self._update_usage(message.usage) - elif chunk.type == "message_delta": - if chunk.delta.stop_reason: - self._stop_reason = normalize_finish_reason( - chunk.delta.stop_reason + def __enter__(self) -> "MessagesStreamWrapper[ResponseFormatT]": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: + try: + if exc_type is not None: + self._fail( + str(exc_val), type(exc_val) if exc_val else Exception ) - self._update_usage(chunk.usage) - elif self._capture_content and chunk.type == "content_block_start": - self._content_blocks[chunk.index] = create_stream_block_state( - chunk.content_block + finally: + self.close() + return False + + def close(self) -> None: + try: + self.stream.close() + finally: + self._stop() + + def __iter__(self) -> "MessagesStreamWrapper[ResponseFormatT]": + return self + + def __next__( + self, + ) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]": + try: + chunk = next(self.stream) + except StopIteration: + self._stop() + raise + except Exception as exc: + self._fail(str(exc), type(exc)) + raise + with self._safe_instrumentation("stream chunk processing"): + self._process_chunk(chunk) + return chunk + + def __getattr__(self, name: str) -> object: + return getattr(self.stream, name) + + @property + def response(self): + return _ResponseProxy(self.stream.response, self._stop) + + def _stop(self) -> None: + if self._finalized: + return + with self._safe_instrumentation("response attribute extraction"): + _set_response_attributes( + self.invocation, self._message, self._capture_content ) - elif self._capture_content and chunk.type == "content_block_delta": - block = self._content_blocks.get(chunk.index) - if block is not None: - update_stream_block_state(block, chunk.delta) + with self._safe_instrumentation("stop_llm"): + self.handler.stop_llm(self.invocation) + self._finalized = True + + def _fail(self, message: str, error_type: type[BaseException]) -> None: + if self._finalized: + return + with self._safe_instrumentation("fail_llm"): + self.handler.fail_llm( + self.invocation, Error(message=message, type=error_type) + ) + self._finalized = True @staticmethod + @contextmanager def _safe_instrumentation( - callback: Callable[[], object], context: str - ) -> None: + context: str, + ) -> Generator[None, None, None]: try: - callback() + yield except Exception: # pylint: disable=broad-exception-caught _logger.debug( "Anthropic MessagesStreamWrapper instrumentation error in %s", @@ -170,90 +229,51 @@ def _safe_instrumentation( exc_info=True, ) - def _set_invocation_response_attributes(self) -> None: - """Extract accumulated stream state into the invocation.""" - if self._response_model: - self._invocation.response_model_name = self._response_model - if self._response_id: - self._invocation.response_id = self._response_id - if self._stop_reason: - self._invocation.finish_reasons = [self._stop_reason] - if self._input_tokens is not None: - self._invocation.input_tokens = self._input_tokens - if self._output_tokens is not None: - self._invocation.output_tokens = self._output_tokens - if self._cache_creation_input_tokens is not None: - self._invocation.attributes[ - GEN_AI_USAGE_CACHE_CREATION_INPUT_TOKENS - ] = self._cache_creation_input_tokens - if self._cache_read_input_tokens is not None: - self._invocation.attributes[ - GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS - ] = self._cache_read_input_tokens - - if self._capture_content and self._content_blocks: - parts: list[MessagePart] = [] - for index in sorted(self._content_blocks): - part = stream_block_state_to_part(self._content_blocks[index]) - if part is not None: - parts.append(part) - self._invocation.output_messages = [ - OutputMessage( - role="assistant", - parts=parts, - finish_reason=self._stop_reason or "", - ) - ] - - def _stop(self) -> None: - if self._finalized: - return - self._safe_instrumentation( - self._set_invocation_response_attributes, - "response attribute extraction", - ) - self._safe_instrumentation( - lambda: self._handler.stop_llm(self._invocation), - "stop_llm", + def _process_chunk( + self, + chunk: "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]", + ) -> None: + """Accumulate a final message snapshot from a streaming chunk.""" + snapshot = cast( + "ParsedMessage[ResponseFormatT] | None", + getattr(self.stream, "current_message_snapshot", None), ) - self._finalized = True - - def _fail(self, message: str, error_type: type[BaseException]) -> None: - if self._finalized: + if snapshot is not None: + self._message = snapshot + return + if accumulate_event is None: return - self._safe_instrumentation( - lambda: self._handler.fail_llm( - self._invocation, Error(message=message, type=error_type) + self._message = accumulate_event( + event=cast("RawMessageStreamEvent", chunk), + current_snapshot=cast( + "ParsedMessage[ResponseFormatT] | None", self._message ), - "fail_llm", ) - self._finalized = True - def __iter__(self) -> MessagesStreamWrapper: - return self - def __getattr__(self, name: str) -> object: - return getattr(self._stream, name) +class AsyncMessagesStreamWrapper(MessagesStreamWrapper[ResponseFormatT]): + """Wrapper for async Anthropic Stream that handles telemetry.""" - def __next__(self) -> RawMessageStreamEvent: - try: - chunk = next(self._stream) - except StopIteration: - self._stop() - raise - except Exception as exc: - self._fail(str(exc), type(exc)) - raise - self._safe_instrumentation( - lambda: self._process_chunk(chunk), - "stream chunk processing", - ) - return chunk + def __init__( + self, + stream: "AsyncStream[RawMessageStreamEvent] | AsyncMessageStream[ResponseFormatT]", + handler: TelemetryHandler, + invocation: LLMInvocation, + capture_content: bool, + ): + self.stream = stream + self.handler = handler + self.invocation = invocation + self._message: "Message | ParsedMessage[ResponseFormatT] | None" = None + self._capture_content = capture_content + self._finalized = False - def __enter__(self) -> MessagesStreamWrapper: + async def __aenter__( + self, + ) -> "AsyncMessagesStreamWrapper[ResponseFormatT]": return self - def __exit__( + async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, @@ -265,11 +285,151 @@ def __exit__( str(exc_val), type(exc_val) if exc_val else Exception ) finally: - self.close() + await self.close() return False - def close(self) -> None: + async def close(self) -> None: # type: ignore[override] try: - self._stream.close() + await self.stream.close() finally: self._stop() + + def __aiter__(self) -> "AsyncMessagesStreamWrapper[ResponseFormatT]": + return self + + @property + def response(self) -> Any: + return _AsyncResponseProxy(self.stream.response, self._stop) + + async def __anext__( + self, + ) -> "RawMessageStreamEvent | ParsedMessageStreamEvent[ResponseFormatT]": + try: + chunk = await self.stream.__anext__() + except StopAsyncIteration: + self._stop() + raise + except Exception as exc: + self._fail(str(exc), type(exc)) + raise + with self._safe_instrumentation("stream chunk processing"): + self._process_chunk(chunk) + return chunk + + +class MessagesStreamManagerWrapper(Generic[ResponseFormatT]): + """Wrapper for sync Anthropic stream managers.""" + + def __init__( + self, + manager: "MessageStreamManager[ResponseFormatT]", + handler: TelemetryHandler, + invocation: LLMInvocation, + capture_content: bool, + ): + self._manager = manager + self._handler = handler + self._invocation = invocation + self._capture_content = capture_content + self._stream_wrapper: MessagesStreamWrapper[ResponseFormatT] | None = ( + None + ) + + def __enter__(self) -> MessagesStreamWrapper[ResponseFormatT]: + stream = self._manager.__enter__() + self._stream_wrapper = MessagesStreamWrapper( + stream, + self._handler, + self._invocation, + self._capture_content, + ) + return self._stream_wrapper + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + suppressed = False + stream_wrapper = self._stream_wrapper + self._stream_wrapper = None + with ExitStack() as cleanup: + if stream_wrapper is not None: + + def finalize_stream_wrapper() -> None: + if suppressed: + stream_wrapper.__exit__(None, None, None) + else: + stream_wrapper.__exit__(exc_type, exc_val, exc_tb) + + cleanup.callback(finalize_stream_wrapper) + suppressed = self._manager.__exit__(exc_type, exc_val, exc_tb) + return suppressed + + def __getattr__(self, name: str) -> object: + return getattr(self._manager, name) + + +class AsyncMessagesStreamManagerWrapper(Generic[ResponseFormatT]): + """Wrapper for AsyncMessageStreamManager that handles telemetry. + + Wraps AsyncMessageStreamManager from the Anthropic SDK: + https://github.com/anthropics/anthropic-sdk-python/blob/05220bc1c1079fe01f5c4babc007ec7a990859d9/src/anthropic/lib/streaming/_messages.py#L294 + """ + + def __init__( + self, + manager: "AsyncMessageStreamManager[ResponseFormatT]", + handler: TelemetryHandler, + invocation: LLMInvocation, + capture_content: bool, + ): + self._manager = manager + self._handler = handler + self._invocation = invocation + self._capture_content = capture_content + self._stream_wrapper: ( + AsyncMessagesStreamWrapper[ResponseFormatT] | None + ) = None + + async def __aenter__( + self, + ) -> AsyncMessagesStreamWrapper[ResponseFormatT]: + msg_stream = await self._manager.__aenter__() + self._stream_wrapper = AsyncMessagesStreamWrapper( + msg_stream, + self._handler, + self._invocation, + self._capture_content, + ) + return self._stream_wrapper + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + suppressed = False + stream_wrapper = self._stream_wrapper + self._stream_wrapper = None + async with AsyncExitStack() as cleanup: + if stream_wrapper is not None: + + async def finalize_stream_wrapper() -> None: + if suppressed: + await stream_wrapper.__aexit__(None, None, None) + else: + await stream_wrapper.__aexit__( + exc_type, exc_val, exc_tb + ) + + cleanup.push_async_callback(finalize_stream_wrapper) + suppressed = await self._manager.__aexit__( + exc_type, exc_val, exc_tb + ) + return suppressed + + def __getattr__(self, name: str) -> object: + return getattr(self._manager, name) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_async_wrappers.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_async_wrappers.py new file mode 100644 index 0000000000..b2055a525f --- /dev/null +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_async_wrappers.py @@ -0,0 +1,536 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +import pytest + +from opentelemetry.instrumentation.anthropic.wrappers import ( + AsyncMessagesStreamManagerWrapper, + AsyncMessagesStreamWrapper, + MessagesStreamManagerWrapper, + MessagesStreamWrapper, +) + + +def _noop_stop_llm(invocation): + del invocation + + +def _noop_fail_llm(invocation, error): + del invocation + del error + + +def _make_handler(): + return SimpleNamespace( + stop_llm=_noop_stop_llm, + fail_llm=_noop_fail_llm, + ) + + +def _make_invocation(): + return SimpleNamespace(attributes={}, request_model=None) + + +def _make_stream_wrapper(stream, handler=None): + return MessagesStreamWrapper( + stream=stream, + handler=handler or _make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + +def _make_async_stream_wrapper(stream, handler=None): + return AsyncMessagesStreamWrapper( + stream=stream, + handler=handler or _make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + +class _FakeSyncStream: + def __init__(self, *, events=None, error=None): + self._events = list(events or []) + self._error = error + self.close_calls = 0 + self.response = _FakeSyncResponse() + + def __iter__(self): + return self + + def __next__(self): + if self._events: + return self._events.pop(0) + if self._error is not None: + raise self._error + raise StopIteration + + def close(self): + self.close_calls += 1 + + +class _FakeAsyncStream: + def __init__(self, *, events=None, error=None): + self._events = list(events or []) + self._error = error + self.close_calls = 0 + self.final_message = SimpleNamespace(id="msg_final") + self.response = _FakeAsyncResponse() + + async def __anext__(self): + if self._events: + return self._events.pop(0) + if self._error is not None: + raise self._error + raise StopAsyncIteration + + async def close(self): + self.close_calls += 1 + + async def get_final_message(self): + return self.final_message + + +class _FakeSyncManager: + def __init__(self, stream, suppressed=False, exit_error=None): + self._stream = stream + self._suppressed = suppressed + self._exit_error = exit_error + self.exit_args = None + + def __enter__(self): + return self._stream + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + if self._exit_error is not None: + raise self._exit_error + return self._suppressed + + +class _FakeAsyncManager: + def __init__(self, stream, suppressed=False, exit_error=None): + self._stream = stream + self._suppressed = suppressed + self._exit_error = exit_error + self.exit_args = None + + async def __aenter__(self): + return self._stream + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + if self._exit_error is not None: + raise self._exit_error + return self._suppressed + + +class _FakeStreamWrapper: + def __init__(self): + self.exit_args = None + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + return False + + +class _FakeAsyncStreamWrapper: + def __init__(self): + self.exit_args = None + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.exit_args = (exc_type, exc_val, exc_tb) + return False + + +class _FakeSyncResponse: + def __init__(self): + self.request_id = "req_sync" + self.close_calls = 0 + + def close(self): + self.close_calls += 1 + + +class _FakeAsyncResponse: + def __init__(self): + self.request_id = "req_async" + self.close_calls = 0 + self.aclose_calls = 0 + + def close(self): + self.close_calls += 1 + + async def aclose(self): + self.aclose_calls += 1 + + +def test_sync_stream_wrapper_exit_closes_without_exception(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + result = wrapper.__exit__(None, None, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + + +def test_sync_stream_wrapper_exit_fails_and_closes_on_exception(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + stopped = [] + failures = [] + + wrapper._stop = lambda: stopped.append(True) + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + error = ValueError("boom") + result = wrapper.__exit__(ValueError, error, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + assert failures == [("boom", ValueError)] + + +def test_sync_stream_wrapper_processes_events_and_stops_on_completion(): + event = SimpleNamespace(type="message_start") + stream = _FakeSyncStream(events=[event]) + wrapper = _make_stream_wrapper(stream) + processed = [] + stopped = [] + + wrapper._process_chunk = processed.append + wrapper._stop = lambda: stopped.append(True) + + result = next(wrapper) + + assert result is event + assert processed == [event] + + with pytest.raises(StopIteration): + next(wrapper) + + assert stopped == [True] + + +def test_sync_stream_wrapper_fails_and_reraises_stream_errors(): + error = ValueError("boom") + stream = _FakeSyncStream(error=error) + wrapper = _make_stream_wrapper(stream) + failures = [] + + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + with pytest.raises(ValueError, match="boom"): + next(wrapper) + + assert failures == [("boom", ValueError)] + + +def test_sync_stream_wrapper_getattr_passthrough(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + + assert wrapper.response.request_id == "req_sync" + + +def test_sync_stream_response_close_finalizes_wrapper(): + stream = _FakeSyncStream() + wrapper = _make_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + wrapper.response.close() + + assert stream.response.close_calls == 1 + assert stopped == [True] + + +def test_sync_manager_enter_constructs_stream_wrapper(): + stream = _FakeSyncStream() + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager(stream=stream), + handler=_make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + with wrapper as result: + assert isinstance(result, MessagesStreamWrapper) + assert result.stream is stream + assert wrapper._stream_wrapper is result + + +def test_sync_manager_exit_forwards_exception_to_stream_wrapper(): + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager(stream=SimpleNamespace(), suppressed=False), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("boom") + result = wrapper.__exit__(ValueError, error, None) + + assert result is False + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) + + +def test_sync_manager_exit_uses_none_exception_when_manager_suppresses(): + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager(stream=SimpleNamespace(), suppressed=True), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = RuntimeError("ignored") + result = wrapper.__exit__(RuntimeError, error, None) + + assert result is True + assert wrapper._manager.exit_args == (RuntimeError, error, None) + assert stream_wrapper.exit_args == (None, None, None) + + +def test_sync_manager_exit_still_finalizes_stream_wrapper_when_manager_raises(): + manager_error = RuntimeError("manager failure") + wrapper = MessagesStreamManagerWrapper( + manager=_FakeSyncManager( + stream=SimpleNamespace(), + suppressed=False, + exit_error=manager_error, + ), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("outer") + with pytest.raises(RuntimeError, match="manager failure"): + wrapper.__exit__(ValueError, error, None) + + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_exit_closes_without_exception(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + result = await wrapper.__aexit__(None, None, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_exit_fails_and_closes_on_exception(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + failures = [] + + wrapper._stop = lambda: stopped.append(True) + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + error = ValueError("boom") + result = await wrapper.__aexit__(ValueError, error, None) + + assert result is False + assert stream.close_calls == 1 + assert stopped == [True] + assert failures == [("boom", ValueError)] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_close_uses_close_and_stops(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + await wrapper.close() + + assert stream.close_calls == 1 + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_processes_events_and_stops_on_completion(): + event = SimpleNamespace(type="message_start") + stream = _FakeAsyncStream(events=[event]) + wrapper = _make_async_stream_wrapper(stream) + processed = [] + stopped = [] + + wrapper._process_chunk = processed.append + wrapper._stop = lambda: stopped.append(True) + + result = await wrapper.__anext__() + + assert result is event + assert processed == [event] + + with pytest.raises(StopAsyncIteration): + await wrapper.__anext__() + + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_fails_and_reraises_stream_errors(): + error = ValueError("boom") + stream = _FakeAsyncStream(error=error) + wrapper = _make_async_stream_wrapper(stream) + failures = [] + + wrapper._fail = lambda message, error_type: failures.append( + (message, error_type) + ) + + with pytest.raises(ValueError, match="boom"): + await wrapper.__anext__() + + assert failures == [("boom", ValueError)] + + +@pytest.mark.asyncio +async def test_async_stream_wrapper_preserves_stream_helper_methods(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + + result = await wrapper.get_final_message() + + assert result is stream.final_message + assert wrapper.response.request_id == "req_async" + + +@pytest.mark.asyncio +async def test_async_stream_response_aclose_finalizes_wrapper(): + stream = _FakeAsyncStream() + wrapper = _make_async_stream_wrapper(stream) + stopped = [] + + wrapper._stop = lambda: stopped.append(True) + + await wrapper.response.aclose() + + assert stream.response.aclose_calls == 1 + assert stopped == [True] + + +@pytest.mark.asyncio +async def test_async_manager_enter_constructs_async_stream_wrapper(): + stream = _FakeAsyncStream() + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager(stream=stream), + handler=_make_handler(), + invocation=_make_invocation(), + capture_content=False, + ) + + async with wrapper as result: + assert isinstance(result, AsyncMessagesStreamWrapper) + assert result.stream is stream + assert wrapper._stream_wrapper is result + + +@pytest.mark.asyncio +async def test_async_manager_exit_forwards_exception_to_stream_wrapper(): + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager(stream=SimpleNamespace(), suppressed=False), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeAsyncStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("boom") + result = await wrapper.__aexit__(ValueError, error, None) + + assert result is False + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) + + +@pytest.mark.asyncio +async def test_async_manager_exit_uses_none_exception_when_manager_suppresses(): + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager(stream=SimpleNamespace(), suppressed=True), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeAsyncStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = RuntimeError("ignored") + result = await wrapper.__aexit__(RuntimeError, error, None) + + assert result is True + assert wrapper._manager.exit_args == (RuntimeError, error, None) + assert stream_wrapper.exit_args == (None, None, None) + + +@pytest.mark.asyncio +async def test_async_manager_exit_still_finalizes_stream_wrapper_when_manager_raises(): + manager_error = RuntimeError("manager failure") + wrapper = AsyncMessagesStreamManagerWrapper( + manager=_FakeAsyncManager( + stream=SimpleNamespace(), + suppressed=False, + exit_error=manager_error, + ), + handler=SimpleNamespace(), + invocation=_make_invocation(), + capture_content=False, + ) + stream_wrapper = _FakeAsyncStreamWrapper() + wrapper._stream_wrapper = stream_wrapper + + error = ValueError("outer") + with pytest.raises(RuntimeError, match="manager failure"): + await wrapper.__aexit__(ValueError, error, None) + + assert wrapper._manager.exit_args == (ValueError, error, None) + assert stream_wrapper.exit_args == (ValueError, error, None) diff --git a/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py index 6bce6f2e16..a073f3b569 100644 --- a/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py +++ b/instrumentation-genai/opentelemetry-instrumentation-anthropic/tests/test_sync_messages.py @@ -826,7 +826,7 @@ def __getattr__(self, name): return getattr(self._inner, name) monkeypatch.setattr( - stream, "_stream", ErrorInjectingStreamDelegate(stream._stream) + stream, "stream", ErrorInjectingStreamDelegate(stream.stream) ) with pytest.raises(