From 4a5387cbac6a03ec65d3288078c8e23763a89c11 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Mon, 23 Feb 2026 16:29:45 +0200 Subject: [PATCH 1/2] chore: ensure injected tool invocation results are appended to request body Signed-off-by: Danny Kopping --- bridge_integration_test.go | 100 ++++++++++++++++++++++++++++++-- intercept/messages/base.go | 21 ++++++- intercept/messages/blocking.go | 9 ++- intercept/messages/streaming.go | 9 ++- trace_integration_test.go | 4 +- 5 files changed, 132 insertions(+), 11 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index cceb6ae9..2d12bea0 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -966,7 +966,7 @@ func TestAnthropicInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) + recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq, anthropicToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() @@ -1056,7 +1056,7 @@ func TestOpenAIInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) + recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t)) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() @@ -1147,9 +1147,99 @@ func TestOpenAIInjectedTools(t *testing.T) { } } +// anthropicToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_use and user's tool_result messages +// appended by the inner agentic loop. If the raw payload is not kept in sync with +// the structured messages, the second request will be identical to the first. +func anthropicToolResultValidator(t *testing.T) func(*http.Request) { + t.Helper() + + var reqNum atomic.Uint32 + return func(r *http.Request) { + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + r.Body = io.NopCloser(bytes.NewReader(raw)) + + if reqNum.Add(1) != 2 { + return + } + + messages := gjson.GetBytes(raw, "messages").Array() + + // After the agentic loop the messages must contain at minimum: + // [0] original user message + // [N-2] assistant message with tool_use content block + // [N-1] user message with tool_result content block + require.GreaterOrEqual(t, len(messages), 3, + "second upstream request must contain the original message, assistant tool_use, and user tool_result") + + assistantMsg := messages[len(messages)-2] + require.Equal(t, "assistant", assistantMsg.Get("role").Str, + "penultimate message must be from the assistant") + var hasToolUse bool + for _, block := range assistantMsg.Get("content").Array() { + if block.Get("type").Str == "tool_use" { + hasToolUse = true + break + } + } + require.True(t, hasToolUse, "assistant message must contain a tool_use content block") + + toolResultMsg := messages[len(messages)-1] + require.Equal(t, "user", toolResultMsg.Get("role").Str, + "last message must be a user message carrying the tool_result") + var hasToolResult bool + for _, block := range toolResultMsg.Get("content").Array() { + if block.Get("type").Str == "tool_result" { + hasToolResult = true + break + } + } + require.True(t, hasToolResult, "user message must contain a tool_result content block") + } +} + +// openaiChatToolResultValidator returns a request validator that asserts the second +// upstream request contains the assistant's tool_calls and a role=tool result message +// appended by the inner agentic loop. +func openaiChatToolResultValidator(t *testing.T) func(*http.Request) { + t.Helper() + + var reqNum atomic.Uint32 + return func(r *http.Request) { + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + r.Body = io.NopCloser(bytes.NewReader(raw)) + + if reqNum.Add(1) != 2 { + return + } + + messages := gjson.GetBytes(raw, "messages").Array() + + // After the agentic loop the messages must contain at minimum: + // [0] original user message + // [N-2] assistant message with tool_calls array + // [N-1] message with role=tool + require.GreaterOrEqual(t, len(messages), 3, + "second upstream request must contain the original message, assistant tool_calls, and tool result") + + assistantMsg := messages[len(messages)-2] + require.Equal(t, "assistant", assistantMsg.Get("role").Str, + "penultimate message must be from the assistant") + require.Greater(t, len(assistantMsg.Get("tool_calls").Array()), 0, + "assistant message must contain a tool_calls array") + + toolResultMsg := messages[len(messages)-1] + require.Equal(t, "tool", toolResultMsg.Get("role").Str, + "last message must have role=tool") + require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str, + "tool result message must have a tool_call_id") + } +} + // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. -// Kinda fugly right now, we can refactor this later. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request, requestValidatorFn func(*http.Request)) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { t.Helper() arc := txtar.Parse(fixture) @@ -1174,7 +1264,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu t.Cleanup(cancel) // Setup mock server with response mutator for multi-turn interaction. - mockSrv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte { + mockSrv := newMockServer(ctx, t, files, requestValidatorFn, func(reqCount uint32, resp []byte) []byte { if reqCount == 1 { return resp // First request gets the normal response (with tool call). } diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 753bdad0..2e6b9f48 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -182,11 +182,28 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio i.augmentRequestForBedrock() } - // Must be after any request augmentation, eg. i.augmentRequestForBedrock() and i.injectTools() - opts = append(opts, option.WithRequestBody("application/json", i.payload)) return anthropic.NewMessageService(opts...), nil } +// withBody returns a per-request option that sends the current i.payload as the +// request body. This is called for each API request so that the latest payload (including +// any messages appended during the agentic tool loop) is always sent. +func (i *interceptionBase) withBody() option.RequestOption { + return option.WithRequestBody("application/json", i.payload) +} + +// syncPayloadMessages updates the raw payload's "messages" field to match the given messages. +// This must be called before the next API request in the agentic loop so that +// payloadBodyOption() picks up the updated messages. +func (i *interceptionBase) syncPayloadMessages(messages []anthropic.MessageParam) error { + var err error + i.payload, err = sjson.SetBytes(i.payload, "messages", messages) + if err != nil { + return fmt.Errorf("sync payload messages: %w", err) + } + return nil +} + func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) { if cfg == nil { return nil, fmt.Errorf("nil config given") diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 594af484..4447cd97 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -279,6 +279,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult)) } } + + // Sync the raw payload with updated messages so that payloadBodyOption() + // sends the updated payload on the next iteration. + if err := i.syncPayloadMessages(messages.Messages); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return fmt.Errorf("sync payload for agentic loop: %w", err) + } } if resp == nil { @@ -308,5 +315,5 @@ func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.Mes ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) - return svc.New(ctx, msgParams) + return svc.New(ctx, msgParams, i.withBody()) } diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 627933ca..d8726019 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -389,6 +389,13 @@ newStream: } } + // Sync the raw payload with updated messages so that payloadBodyOption() + // sends the updated payload on the next iteration. + if syncErr := i.syncPayloadMessages(messages.Messages); syncErr != nil { + lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr) + break + } + // Causes a new stream to be run with updated messages. isFirst = false continue newStream @@ -551,5 +558,5 @@ func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.Mes _, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return svc.NewStreaming(ctx, messages) + return svc.NewStreaming(ctx, messages, s.withBody()) } diff --git a/trace_integration_test.go b/trace_integration_test.go index 38c6b27d..3fd0fbe4 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -346,7 +346,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, tc.streaming, configureFn, reqFunc) + recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, tc.streaming, configureFn, reqFunc, anthropicToolResultValidator(t)) defer resp.Body.Close() @@ -734,7 +734,7 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, reqFunc) + recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, reqFunc, openaiChatToolResultValidator(t)) defer resp.Body.Close() From b7d851ea747dd8668cfefe119170e10288700df4 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Mon, 23 Feb 2026 17:49:37 +0200 Subject: [PATCH 2/2] chore: review comments Signed-off-by: Danny Kopping --- bridge_integration_test.go | 2 +- intercept/messages/base.go | 2 +- intercept/messages/blocking.go | 2 +- intercept/messages/streaming.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 2d12bea0..ea4c0ae8 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1227,7 +1227,7 @@ func openaiChatToolResultValidator(t *testing.T) func(*http.Request) { assistantMsg := messages[len(messages)-2] require.Equal(t, "assistant", assistantMsg.Get("role").Str, "penultimate message must be from the assistant") - require.Greater(t, len(assistantMsg.Get("tool_calls").Array()), 0, + require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()), "assistant message must contain a tool_calls array") toolResultMsg := messages[len(messages)-1] diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 2e6b9f48..6f4f01fd 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -194,7 +194,7 @@ func (i *interceptionBase) withBody() option.RequestOption { // syncPayloadMessages updates the raw payload's "messages" field to match the given messages. // This must be called before the next API request in the agentic loop so that -// payloadBodyOption() picks up the updated messages. +// withBody() picks up the updated messages. func (i *interceptionBase) syncPayloadMessages(messages []anthropic.MessageParam) error { var err error i.payload, err = sjson.SetBytes(i.payload, "messages", messages) diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 4447cd97..7ab2bedf 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -280,7 +280,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req } } - // Sync the raw payload with updated messages so that payloadBodyOption() + // Sync the raw payload with updated messages so that withBody() // sends the updated payload on the next iteration. if err := i.syncPayloadMessages(messages.Messages); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index d8726019..4fc19fdf 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -389,7 +389,7 @@ newStream: } } - // Sync the raw payload with updated messages so that payloadBodyOption() + // Sync the raw payload with updated messages so that withBody() // sends the updated payload on the next iteration. if syncErr := i.syncPayloadMessages(messages.Messages); syncErr != nil { lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr)