|
1 | 1 | """Anthropic provider implementation.""" |
2 | 2 |
|
3 | 3 | import json |
| 4 | +import logging |
4 | 5 | import os |
5 | 6 | from typing import Any |
6 | 7 |
|
7 | 8 | from .base import LLMProvider, LLMResponse, register_provider |
8 | 9 |
|
| 10 | +logger = logging.getLogger(__name__) |
| 11 | + |
| 12 | +ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"} |
| 13 | + |
| 14 | + |
| 15 | +def _apply_cache_control(request_params: dict[str, Any]) -> None: |
| 16 | + """Add Anthropic prompt caching breakpoints to request params (in-place). |
| 17 | +
|
| 18 | + Marks the system prompt, the last tool, and the last message with |
| 19 | + cache_control so Anthropic can cache the static prefix across calls. |
| 20 | + """ |
| 21 | + # 1. System prompt: convert string to content block list with cache control |
| 22 | + system = request_params.get("system") |
| 23 | + if isinstance(system, str): |
| 24 | + request_params["system"] = [ |
| 25 | + {"type": "text", "text": system, "cache_control": ANTHROPIC_CACHE_CONTROL} |
| 26 | + ] |
| 27 | + elif isinstance(system, list) and system: |
| 28 | + # Already a list of content blocks — mark the last one |
| 29 | + system[-1] = {**system[-1], "cache_control": ANTHROPIC_CACHE_CONTROL} |
| 30 | + |
| 31 | + # 2. Tools: mark the last tool |
| 32 | + tools = request_params.get("tools") |
| 33 | + if tools and isinstance(tools, list) and len(tools) > 0: |
| 34 | + tools[-1] = {**tools[-1], "cache_control": ANTHROPIC_CACHE_CONTROL} |
| 35 | + |
| 36 | + # 3. Messages: mark the last content block of the last message |
| 37 | + messages = request_params.get("messages") |
| 38 | + if messages and isinstance(messages, list) and len(messages) > 0: |
| 39 | + last_msg = messages[-1] |
| 40 | + content = last_msg.get("content") if isinstance(last_msg, dict) else None |
| 41 | + if isinstance(content, str): |
| 42 | + # Convert string content to content block list with cache control |
| 43 | + messages[-1] = { |
| 44 | + **last_msg, |
| 45 | + "content": [ |
| 46 | + {"type": "text", "text": content, "cache_control": ANTHROPIC_CACHE_CONTROL} |
| 47 | + ], |
| 48 | + } |
| 49 | + elif isinstance(content, list) and len(content) > 0: |
| 50 | + # Mark the last content block |
| 51 | + content[-1] = {**content[-1], "cache_control": ANTHROPIC_CACHE_CONTROL} |
| 52 | + |
| 53 | + |
| 54 | +def _extract_cache_usage(usage_data: Any) -> dict[str, int]: |
| 55 | + """Extract cache token fields from Anthropic usage data.""" |
| 56 | + result: dict[str, int] = {} |
| 57 | + if usage_data: |
| 58 | + cache_read = getattr(usage_data, "cache_read_input_tokens", None) |
| 59 | + if cache_read is not None: |
| 60 | + result["cache_read_input_tokens"] = cache_read |
| 61 | + cache_creation = getattr(usage_data, "cache_creation_input_tokens", None) |
| 62 | + if cache_creation is not None: |
| 63 | + result["cache_creation_input_tokens"] = cache_creation |
| 64 | + return result |
| 65 | + |
9 | 66 |
|
10 | 67 | @register_provider("anthropic") |
11 | 68 | class AnthropicProvider(LLMProvider): |
@@ -149,6 +206,10 @@ async def generate( |
149 | 206 |
|
150 | 207 | # Add any additional kwargs |
151 | 208 | request_params.update(kwargs) |
| 209 | + |
| 210 | + # Apply prompt caching breakpoints |
| 211 | + _apply_cache_control(request_params) |
| 212 | + |
152 | 213 | try: |
153 | 214 | # Use the SDK's Messages API |
154 | 215 | response = await self.client.messages.create(**request_params) |
@@ -202,6 +263,7 @@ async def generate( |
202 | 263 | "total_tokens": (usage_data.input_tokens + usage_data.output_tokens) |
203 | 264 | if usage_data |
204 | 265 | else 0, |
| 266 | + **_extract_cache_usage(usage_data), |
205 | 267 | } |
206 | 268 |
|
207 | 269 | # Extract model and stop_reason from response |
@@ -344,6 +406,10 @@ async def stream( |
344 | 406 |
|
345 | 407 | # Add any additional kwargs |
346 | 408 | request_params.update(kwargs) |
| 409 | + |
| 410 | + # Apply prompt caching breakpoints |
| 411 | + _apply_cache_control(request_params) |
| 412 | + |
347 | 413 | try: |
348 | 414 | # Use the SDK's Messages API with streaming |
349 | 415 | stream = await self.client.messages.create(**request_params) |
@@ -512,6 +578,14 @@ async def stream( |
512 | 578 | usage["input_tokens"] = usage_data.get("input_tokens") |
513 | 579 | if usage_data.get("output_tokens"): |
514 | 580 | usage["output_tokens"] = usage_data.get("output_tokens") |
| 581 | + if usage_data.get("cache_read_input_tokens") is not None: |
| 582 | + usage["cache_read_input_tokens"] = usage_data.get( |
| 583 | + "cache_read_input_tokens" |
| 584 | + ) |
| 585 | + if usage_data.get("cache_creation_input_tokens") is not None: |
| 586 | + usage["cache_creation_input_tokens"] = usage_data.get( |
| 587 | + "cache_creation_input_tokens" |
| 588 | + ) |
515 | 589 |
|
516 | 590 | elif event_type == "message_stop": |
517 | 591 | # Stream complete - final event |
|
0 commit comments