Skip to content

Commit c1e7416

Browse files
committed: sdk/python: Anthropic prompt caching
1 parent ca630a0 commit c1e7416

File tree

7 files changed

+213
-9
lines changed

7 files changed

+213
-9
lines changed

sdk/python/polos/agents/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def __init__(
386386
guardrails: Callable | str | list[Callable | str] | None = None,
387387
guardrail_max_retries: int = 2,
388388
conversation_history: int = 10, # Number of messages to keep
389+
stream_to_workflow: bool = False,
389390
):
390391
# Parse queue configuration (same as task decorator)
391392
queue_name: str | None = None
@@ -442,6 +443,9 @@ def __init__(
442443
# Conversation history
443444
self.conversation_history = conversation_history
444445

446+
# Stream to workflow topic for all invocations
447+
self.stream_to_workflow = stream_to_workflow
448+
445449
# Convert Pydantic model to JSON schema if provided
446450
self._output_json_schema, self._output_schema_name = convert_output_schema(
447451
output_schema, context_id=self.id
@@ -505,7 +509,7 @@ async def _agent_execute(self, ctx: AgentContext, payload: dict[str, Any]) -> di
505509
)
506510

507511
input_data = payload.get("input")
508-
streaming = payload.get("streaming", False) # Whether to stream or return final result
512+
streaming = payload.get("streaming", False) or self.stream_to_workflow
509513
provider_kwargs = payload.get(
510514
"provider_kwargs", {}
511515
) # Additional kwargs to pass to provider

sdk/python/polos/agents/stream.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
118118
final_input_tokens = 0
119119
final_output_tokens = 0
120120
final_total_tokens = 0
121+
final_cache_read_input_tokens = 0
122+
final_cache_creation_input_tokens = 0
121123
last_llm_result_content = None
122124
all_tool_results = []
123125
steps: list[Step] = []
@@ -242,6 +244,10 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
242244
final_input_tokens += usage_dict.get("input_tokens", 0)
243245
final_output_tokens += usage_dict.get("output_tokens", 0)
244246
final_total_tokens += usage_dict.get("total_tokens", 0)
247+
if usage_dict.get("cache_read_input_tokens"):
248+
final_cache_read_input_tokens += usage_dict["cache_read_input_tokens"]
249+
if usage_dict.get("cache_creation_input_tokens"):
250+
final_cache_creation_input_tokens += usage_dict["cache_creation_input_tokens"]
245251

246252
last_llm_result_content = llm_result.get("content")
247253
tool_calls = llm_result.get("tool_calls") or []
@@ -555,6 +561,16 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
555561
"input_tokens": final_input_tokens,
556562
"output_tokens": final_output_tokens,
557563
"total_tokens": final_total_tokens,
564+
**(
565+
{"cache_read_input_tokens": final_cache_read_input_tokens}
566+
if final_cache_read_input_tokens > 0
567+
else {}
568+
),
569+
**(
570+
{"cache_creation_input_tokens": final_cache_creation_input_tokens}
571+
if final_cache_creation_input_tokens > 0
572+
else {}
573+
),
558574
},
559575
}
560576
)

sdk/python/polos/execution/tools/exec.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ async def _request_approval(
4141
"_form": {
4242
"title": "Approve command execution",
4343
"description": (
44-
f"The agent wants to run a shell command in the "
45-
f"{env_info.type} environment."
44+
f"The agent wants to run a shell command in the {env_info.type} environment."
4645
),
4746
"fields": [
4847
{

sdk/python/polos/llm/providers/anthropic.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,68 @@
11
"""Anthropic provider implementation."""
22

33
import json
4+
import logging
45
import os
56
from typing import Any
67

78
from .base import LLMProvider, LLMResponse, register_provider
89

10+
logger = logging.getLogger(__name__)
11+
12+
ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"}
13+
14+
15+
def _apply_cache_control(request_params: dict[str, Any]) -> None:
16+
"""Add Anthropic prompt caching breakpoints to request params (in-place).
17+
18+
Marks the system prompt, the last tool, and the last message with
19+
cache_control so Anthropic can cache the static prefix across calls.
20+
"""
21+
# 1. System prompt: convert string to content block list with cache control
22+
system = request_params.get("system")
23+
if isinstance(system, str):
24+
request_params["system"] = [
25+
{"type": "text", "text": system, "cache_control": ANTHROPIC_CACHE_CONTROL}
26+
]
27+
elif isinstance(system, list) and system:
28+
# Already a list of content blocks — mark the last one
29+
system[-1] = {**system[-1], "cache_control": ANTHROPIC_CACHE_CONTROL}
30+
31+
# 2. Tools: mark the last tool
32+
tools = request_params.get("tools")
33+
if tools and isinstance(tools, list) and len(tools) > 0:
34+
tools[-1] = {**tools[-1], "cache_control": ANTHROPIC_CACHE_CONTROL}
35+
36+
# 3. Messages: mark the last content block of the last message
37+
messages = request_params.get("messages")
38+
if messages and isinstance(messages, list) and len(messages) > 0:
39+
last_msg = messages[-1]
40+
content = last_msg.get("content") if isinstance(last_msg, dict) else None
41+
if isinstance(content, str):
42+
# Convert string content to content block list with cache control
43+
messages[-1] = {
44+
**last_msg,
45+
"content": [
46+
{"type": "text", "text": content, "cache_control": ANTHROPIC_CACHE_CONTROL}
47+
],
48+
}
49+
elif isinstance(content, list) and len(content) > 0:
50+
# Mark the last content block
51+
content[-1] = {**content[-1], "cache_control": ANTHROPIC_CACHE_CONTROL}
52+
53+
54+
def _extract_cache_usage(usage_data: Any) -> dict[str, int]:
55+
"""Extract cache token fields from Anthropic usage data."""
56+
result: dict[str, int] = {}
57+
if usage_data:
58+
cache_read = getattr(usage_data, "cache_read_input_tokens", None)
59+
if cache_read is not None:
60+
result["cache_read_input_tokens"] = cache_read
61+
cache_creation = getattr(usage_data, "cache_creation_input_tokens", None)
62+
if cache_creation is not None:
63+
result["cache_creation_input_tokens"] = cache_creation
64+
return result
65+
966

1067
@register_provider("anthropic")
1168
class AnthropicProvider(LLMProvider):
@@ -149,6 +206,10 @@ async def generate(
149206

150207
# Add any additional kwargs
151208
request_params.update(kwargs)
209+
210+
# Apply prompt caching breakpoints
211+
_apply_cache_control(request_params)
212+
152213
try:
153214
# Use the SDK's Messages API
154215
response = await self.client.messages.create(**request_params)
@@ -202,6 +263,7 @@ async def generate(
202263
"total_tokens": (usage_data.input_tokens + usage_data.output_tokens)
203264
if usage_data
204265
else 0,
266+
**_extract_cache_usage(usage_data),
205267
}
206268

207269
# Extract model and stop_reason from response
@@ -344,6 +406,10 @@ async def stream(
344406

345407
# Add any additional kwargs
346408
request_params.update(kwargs)
409+
410+
# Apply prompt caching breakpoints
411+
_apply_cache_control(request_params)
412+
347413
try:
348414
# Use the SDK's Messages API with streaming
349415
stream = await self.client.messages.create(**request_params)
@@ -512,6 +578,14 @@ async def stream(
512578
usage["input_tokens"] = usage_data.get("input_tokens")
513579
if usage_data.get("output_tokens"):
514580
usage["output_tokens"] = usage_data.get("output_tokens")
581+
if usage_data.get("cache_read_input_tokens") is not None:
582+
usage["cache_read_input_tokens"] = usage_data.get(
583+
"cache_read_input_tokens"
584+
)
585+
if usage_data.get("cache_creation_input_tokens") is not None:
586+
usage["cache_creation_input_tokens"] = usage_data.get(
587+
"cache_creation_input_tokens"
588+
)
515589

516590
elif event_type == "message_stop":
517591
# Stream complete - final event

sdk/python/polos/types/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class Usage(BaseModel):
1111
input_tokens: int = 0
1212
output_tokens: int = 0
1313
total_tokens: int = 0
14+
cache_read_input_tokens: int | None = None
15+
cache_creation_input_tokens: int | None = None
1416

1517

1618
class ToolCallFunction(BaseModel):

sdk/python/tests/unit/test_agents/test_agent.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,116 @@ def guardrail2(ctx, guardrail_ctx):
613613
# Test with invalid type
614614
with pytest.raises(TypeError, match="Invalid guardrails type"):
615615
agent._normalize_guardrails(123) # type: ignore
616+
617+
def test_agent_stream_to_workflow_default_false(self):
618+
"""Test Agent stream_to_workflow defaults to False."""
619+
agent = Agent(id="test-agent", model="gpt-4", provider="openai")
620+
assert agent.stream_to_workflow is False
621+
622+
def test_agent_stream_to_workflow_true(self):
623+
"""Test Agent stream_to_workflow can be set to True."""
624+
agent = Agent(
625+
id="test-agent",
626+
model="gpt-4",
627+
provider="openai",
628+
stream_to_workflow=True,
629+
)
630+
assert agent.stream_to_workflow is True
631+
632+
633+
class TestAgentStreamToWorkflow:
634+
"""Tests for stream_to_workflow streaming flag resolution in _agent_execute."""
635+
636+
@pytest.mark.asyncio
637+
async def test_stream_to_workflow_false_payload_streaming_false(self):
638+
"""Default agent: streaming=False in payload → streaming=False passed to stream function."""
639+
agent = Agent(id="test-agent-stw-1", model="gpt-4", provider="openai")
640+
641+
mock_ctx = MagicMock()
642+
mock_ctx.execution_id = "exec-123"
643+
mock_ctx.session_id = "sess-123"
644+
mock_ctx.user_id = "user-123"
645+
mock_ctx.step.uuid = AsyncMock(return_value="conv-123")
646+
647+
with patch(
648+
"polos.agents.stream._agent_stream_function", new_callable=AsyncMock
649+
) as mock_stream:
650+
mock_stream.return_value = {"result": "ok"}
651+
652+
await agent._agent_execute(mock_ctx, {"input": "hello", "streaming": False})
653+
654+
call_args = mock_stream.call_args[0]
655+
assert call_args[1]["streaming"] is False
656+
657+
@pytest.mark.asyncio
658+
async def test_stream_to_workflow_true_payload_streaming_false(self):
659+
"""Agent with stream_to_workflow=True: streaming=False in payload → streaming=True."""
660+
agent = Agent(
661+
id="test-agent-stw-2",
662+
model="gpt-4",
663+
provider="openai",
664+
stream_to_workflow=True,
665+
)
666+
667+
mock_ctx = MagicMock()
668+
mock_ctx.execution_id = "exec-123"
669+
mock_ctx.session_id = "sess-123"
670+
mock_ctx.user_id = "user-123"
671+
mock_ctx.step.uuid = AsyncMock(return_value="conv-123")
672+
673+
with patch(
674+
"polos.agents.stream._agent_stream_function", new_callable=AsyncMock
675+
) as mock_stream:
676+
mock_stream.return_value = {"result": "ok"}
677+
678+
await agent._agent_execute(mock_ctx, {"input": "hello", "streaming": False})
679+
680+
call_args = mock_stream.call_args[0]
681+
assert call_args[1]["streaming"] is True
682+
683+
@pytest.mark.asyncio
684+
async def test_stream_to_workflow_true_payload_streaming_true(self):
685+
"""Agent with stream_to_workflow=True + payload streaming=True → streaming=True."""
686+
agent = Agent(
687+
id="test-agent-stw-3",
688+
model="gpt-4",
689+
provider="openai",
690+
stream_to_workflow=True,
691+
)
692+
693+
mock_ctx = MagicMock()
694+
mock_ctx.execution_id = "exec-123"
695+
mock_ctx.session_id = "sess-123"
696+
mock_ctx.user_id = "user-123"
697+
mock_ctx.step.uuid = AsyncMock(return_value="conv-123")
698+
699+
with patch(
700+
"polos.agents.stream._agent_stream_function", new_callable=AsyncMock
701+
) as mock_stream:
702+
mock_stream.return_value = {"result": "ok"}
703+
704+
await agent._agent_execute(mock_ctx, {"input": "hello", "streaming": True})
705+
706+
call_args = mock_stream.call_args[0]
707+
assert call_args[1]["streaming"] is True
708+
709+
@pytest.mark.asyncio
710+
async def test_stream_to_workflow_false_payload_streaming_true(self):
711+
"""Default agent + payload streaming=True → streaming=True (unchanged)."""
712+
agent = Agent(id="test-agent-stw-4", model="gpt-4", provider="openai")
713+
714+
mock_ctx = MagicMock()
715+
mock_ctx.execution_id = "exec-123"
716+
mock_ctx.session_id = "sess-123"
717+
mock_ctx.user_id = "user-123"
718+
mock_ctx.step.uuid = AsyncMock(return_value="conv-123")
719+
720+
with patch(
721+
"polos.agents.stream._agent_stream_function", new_callable=AsyncMock
722+
) as mock_stream:
723+
mock_stream.return_value = {"result": "ok"}
724+
725+
await agent._agent_execute(mock_ctx, {"input": "hello", "streaming": True})
726+
727+
call_args = mock_stream.call_args[0]
728+
assert call_args[1]["streaming"] is True

sdk/python/tests/unit/test_agents/test_stop_conditions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,12 @@ def test_max_steps_continues_when_count_not_reached(self):
223223
def test_max_steps_default_count(self):
224224
"""Test max_steps uses default count=20."""
225225
config = MaxStepsConfig() # Uses default count=20
226-
ctx = StopConditionContext(
227-
steps=[Step(step=i) for i in range(1, 6)]
228-
)
226+
ctx = StopConditionContext(steps=[Step(step=i) for i in range(1, 6)])
229227
configured = max_steps(config)
230228
result = configured(ctx)
231229
assert result is False # 5 < 20
232230

233-
ctx_at_limit = StopConditionContext(
234-
steps=[Step(step=i) for i in range(1, 21)]
235-
)
231+
ctx_at_limit = StopConditionContext(steps=[Step(step=i) for i in range(1, 21)])
236232
result = configured(ctx_at_limit)
237233
assert result is True # 20 >= 20
238234

0 commit comments

Comments (0)