Skip to content

Commit 94185e1

Browse files
committed
sdk/python: Fix bug in usage calculation when there are cached tokens
1 parent edeb2f6 commit 94185e1

File tree

3 files changed

+52
-36
lines changed

3 files changed

+52
-36
lines changed

sdk/python/polos/agents/stream.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -244,9 +244,9 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
244244
final_input_tokens += usage_dict.get("input_tokens", 0)
245245
final_output_tokens += usage_dict.get("output_tokens", 0)
246246
final_total_tokens += usage_dict.get("total_tokens", 0)
247-
if usage_dict.get("cache_read_input_tokens"):
247+
if usage_dict.get("cache_read_input_tokens") is not None:
248248
final_cache_read_input_tokens += usage_dict["cache_read_input_tokens"]
249-
if usage_dict.get("cache_creation_input_tokens"):
249+
if usage_dict.get("cache_creation_input_tokens") is not None:
250250
final_cache_creation_input_tokens += usage_dict["cache_creation_input_tokens"]
251251

252252
last_llm_result_content = llm_result.get("content")
@@ -621,6 +621,8 @@ async def _agent_stream_function(ctx: AgentContext, payload: dict[str, Any]) ->
621621
input_tokens=usage_dict.get("input_tokens", 0),
622622
output_tokens=usage_dict.get("output_tokens", 0),
623623
total_tokens=usage_dict.get("total_tokens", 0),
624+
cache_read_input_tokens=usage_dict.get("cache_read_input_tokens"),
625+
cache_creation_input_tokens=usage_dict.get("cache_creation_input_tokens"),
624626
)
625627

626628
agent_result = AgentResult(

sdk/python/polos/llm/providers/anthropic.py

Lines changed: 46 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -293,14 +293,22 @@ async def generate(
293293
tool_calls.append(tool_call_data)
294294

295295
# Extract usage information
296+
# Anthropic's input_tokens only counts non-cached tokens.
297+
# Total input = input_tokens + cache_read + cache_creation.
296298
usage_data = response.usage
299+
cache_usage = _extract_cache_usage(usage_data)
300+
raw_input = usage_data.input_tokens if usage_data else 0
301+
total_input = (
302+
raw_input
303+
+ cache_usage.get("cache_read_input_tokens", 0)
304+
+ cache_usage.get("cache_creation_input_tokens", 0)
305+
)
306+
output = usage_data.output_tokens if usage_data else 0
297307
usage = {
298-
"input_tokens": usage_data.input_tokens if usage_data else 0,
299-
"output_tokens": usage_data.output_tokens if usage_data else 0,
300-
"total_tokens": (usage_data.input_tokens + usage_data.output_tokens)
301-
if usage_data
302-
else 0,
303-
**_extract_cache_usage(usage_data),
308+
"input_tokens": total_input,
309+
"output_tokens": output,
310+
"total_tokens": total_input + output,
311+
**cache_usage,
304312
}
305313

306314
# Extract model and stop_reason from response
@@ -478,6 +486,7 @@ async def stream(
478486
else json.dumps(event)
479487
)
480488

489+
481490
if event_type == "content_block_start":
482491
# Content block starting - could be text or tool_use
483492
if event.get("content_block"):
@@ -597,38 +606,43 @@ async def stream(
597606
accumulated_signature = ""
598607

599608
elif event_type in ["message_start", "message_delta"]:
600-
# Message delta - contains stop_reason and usage
601-
message = None
602609
if event_type == "message_start":
603610
message = event.get("message")
611+
if message:
612+
response_model = message.get("model") or response_model
613+
stop_reason = message.get("stop_reason") or stop_reason
614+
usage_data = (message or {}).get("usage")
604615
else:
605-
message = event.get("delta")
606-
607-
if message:
608-
response_model = message.get("model") or response_model # Update if present
609-
stop_reason = message.get("stop_reason") or stop_reason # Update if present
610-
611-
if message.get("usage"):
612-
usage_data = message.get("usage")
613-
if usage_data:
614-
if usage_data.get("input_tokens"):
615-
usage["input_tokens"] = usage_data.get("input_tokens")
616-
if usage_data.get("output_tokens"):
617-
usage["output_tokens"] = usage_data.get("output_tokens")
618-
if usage_data.get("cache_read_input_tokens") is not None:
619-
usage["cache_read_input_tokens"] = usage_data.get(
620-
"cache_read_input_tokens"
621-
)
622-
if usage_data.get("cache_creation_input_tokens") is not None:
623-
usage["cache_creation_input_tokens"] = usage_data.get(
624-
"cache_creation_input_tokens"
625-
)
616+
delta = event.get("delta")
617+
if delta:
618+
stop_reason = delta.get("stop_reason") or stop_reason
619+
# usage lives at the top level for message_delta, not inside delta
620+
usage_data = event.get("usage")
621+
622+
if usage_data:
623+
if usage_data.get("input_tokens") is not None:
624+
usage["input_tokens"] = usage_data["input_tokens"]
625+
if usage_data.get("output_tokens") is not None:
626+
usage["output_tokens"] = usage_data["output_tokens"]
627+
if usage_data.get("cache_read_input_tokens") is not None:
628+
usage["cache_read_input_tokens"] = usage_data[
629+
"cache_read_input_tokens"
630+
]
631+
if usage_data.get("cache_creation_input_tokens") is not None:
632+
usage["cache_creation_input_tokens"] = usage_data[
633+
"cache_creation_input_tokens"
634+
]
626635

627636
elif event_type == "message_stop":
628637
# Stream complete - final event
629-
usage["total_tokens"] = usage.get("input_tokens", 0) + usage.get(
630-
"output_tokens", 0
631-
)
638+
# Anthropic's input_tokens only counts non-cached tokens.
639+
# Total input = input_tokens + cache_read + cache_creation.
640+
raw_input = usage.get("input_tokens", 0)
641+
cache_read = usage.get("cache_read_input_tokens", 0)
642+
cache_creation = usage.get("cache_creation_input_tokens", 0)
643+
total_input = raw_input + cache_read + cache_creation
644+
usage["input_tokens"] = total_input
645+
usage["total_tokens"] = total_input + usage.get("output_tokens", 0)
632646
processed_messages.append(
633647
{
634648
"role": "assistant",

sdk/typescript/src/agents/stream.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -302,10 +302,10 @@ export async function agentStreamFunction(
302302
finalInputTokens += llmResult.usage.input_tokens;
303303
finalOutputTokens += llmResult.usage.output_tokens;
304304
finalTotalTokens += llmResult.usage.total_tokens;
305-
if (llmResult.usage.cache_read_input_tokens) {
305+
if (llmResult.usage.cache_read_input_tokens != null) {
306306
finalCacheReadInputTokens += llmResult.usage.cache_read_input_tokens;
307307
}
308-
if (llmResult.usage.cache_creation_input_tokens) {
308+
if (llmResult.usage.cache_creation_input_tokens != null) {
309309
finalCacheCreationInputTokens += llmResult.usage.cache_creation_input_tokens;
310310
}
311311
}

0 commit comments

Comments
 (0)