@@ -293,14 +293,22 @@ async def generate(
293293 tool_calls .append (tool_call_data )
294294
295295 # Extract usage information
296+ # Anthropic's input_tokens only counts non-cached tokens.
297+ # Total input = input_tokens + cache_read + cache_creation.
296298 usage_data = response .usage
299+ cache_usage = _extract_cache_usage (usage_data )
300+ raw_input = usage_data .input_tokens if usage_data else 0
301+ total_input = (
302+ raw_input
303+ + cache_usage .get ("cache_read_input_tokens" , 0 )
304+ + cache_usage .get ("cache_creation_input_tokens" , 0 )
305+ )
306+ output = usage_data .output_tokens if usage_data else 0
297307 usage = {
298- "input_tokens" : usage_data .input_tokens if usage_data else 0 ,
299- "output_tokens" : usage_data .output_tokens if usage_data else 0 ,
300- "total_tokens" : (usage_data .input_tokens + usage_data .output_tokens )
301- if usage_data
302- else 0 ,
303- ** _extract_cache_usage (usage_data ),
308+ "input_tokens" : total_input ,
309+ "output_tokens" : output ,
310+ "total_tokens" : total_input + output ,
311+ ** cache_usage ,
304312 }
305313
306314 # Extract model and stop_reason from response
@@ -478,6 +486,7 @@ async def stream(
478486 else json .dumps (event )
479487 )
480488
489+
481490 if event_type == "content_block_start" :
482491 # Content block starting - could be text or tool_use
483492 if event .get ("content_block" ):
@@ -597,38 +606,43 @@ async def stream(
597606 accumulated_signature = ""
598607
599608 elif event_type in ["message_start" , "message_delta" ]:
600- # Message delta - contains stop_reason and usage
601- message = None
602609 if event_type == "message_start" :
603610 message = event .get ("message" )
611+ if message :
612+ response_model = message .get ("model" ) or response_model
613+ stop_reason = message .get ("stop_reason" ) or stop_reason
614+ usage_data = (message or {}).get ("usage" )
604615 else :
605- message = event .get ("delta" )
606-
607- if message :
608- response_model = message .get ("model" ) or response_model # Update if present
609- stop_reason = message .get ("stop_reason" ) or stop_reason # Update if present
610-
611- if message .get ("usage" ):
612- usage_data = message .get ("usage" )
613- if usage_data :
614- if usage_data .get ("input_tokens" ):
615- usage ["input_tokens" ] = usage_data .get ("input_tokens" )
616- if usage_data .get ("output_tokens" ):
617- usage ["output_tokens" ] = usage_data .get ("output_tokens" )
618- if usage_data .get ("cache_read_input_tokens" ) is not None :
619- usage ["cache_read_input_tokens" ] = usage_data .get (
620- "cache_read_input_tokens"
621- )
622- if usage_data .get ("cache_creation_input_tokens" ) is not None :
623- usage ["cache_creation_input_tokens" ] = usage_data .get (
624- "cache_creation_input_tokens"
625- )
616+ delta = event .get ("delta" )
617+ if delta :
618+ stop_reason = delta .get ("stop_reason" ) or stop_reason
619+ # usage lives at the top level for message_delta, not inside delta
620+ usage_data = event .get ("usage" )
621+
622+ if usage_data :
623+ if usage_data .get ("input_tokens" ) is not None :
624+ usage ["input_tokens" ] = usage_data ["input_tokens" ]
625+ if usage_data .get ("output_tokens" ) is not None :
626+ usage ["output_tokens" ] = usage_data ["output_tokens" ]
627+ if usage_data .get ("cache_read_input_tokens" ) is not None :
628+ usage ["cache_read_input_tokens" ] = usage_data [
629+ "cache_read_input_tokens"
630+ ]
631+ if usage_data .get ("cache_creation_input_tokens" ) is not None :
632+ usage ["cache_creation_input_tokens" ] = usage_data [
633+ "cache_creation_input_tokens"
634+ ]
626635
627636 elif event_type == "message_stop" :
628637 # Stream complete - final event
629- usage ["total_tokens" ] = usage .get ("input_tokens" , 0 ) + usage .get (
630- "output_tokens" , 0
631- )
638+ # Anthropic's input_tokens only counts non-cached tokens.
639+ # Total input = input_tokens + cache_read + cache_creation.
640+ raw_input = usage .get ("input_tokens" , 0 )
641+ cache_read = usage .get ("cache_read_input_tokens" , 0 )
642+ cache_creation = usage .get ("cache_creation_input_tokens" , 0 )
643+ total_input = raw_input + cache_read + cache_creation
644+ usage ["input_tokens" ] = total_input
645+ usage ["total_tokens" ] = total_input + usage .get ("output_tokens" , 0 )
632646 processed_messages .append (
633647 {
634648 "role" : "assistant" ,
0 commit comments