Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 95 additions & 5 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq, anthropicToolResultValidator(t))

// Ensure expected tool was invoked with expected input.
toolUsages := recorderClient.RecordedToolUsages()
Expand Down Expand Up @@ -1056,7 +1056,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq, openaiChatToolResultValidator(t))

// Ensure expected tool was invoked with expected input.
toolUsages := recorderClient.RecordedToolUsages()
Expand Down Expand Up @@ -1147,9 +1147,99 @@ func TestOpenAIInjectedTools(t *testing.T) {
}
}

// anthropicToolResultValidator returns a request validator that asserts the second
// upstream request contains the assistant's tool_use and user's tool_result messages
// appended by the inner agentic loop. If the raw payload is not kept in sync with
// the structured messages, the second request will be identical to the first.
func anthropicToolResultValidator(t *testing.T) func(*http.Request) {
t.Helper()

var reqNum atomic.Uint32
return func(r *http.Request) {
raw, err := io.ReadAll(r.Body)
require.NoError(t, err)
r.Body = io.NopCloser(bytes.NewReader(raw))

if reqNum.Add(1) != 2 {
return
}

messages := gjson.GetBytes(raw, "messages").Array()

// After the agentic loop the messages must contain at minimum:
// [0] original user message
// [N-2] assistant message with tool_use content block
// [N-1] user message with tool_result content block
require.GreaterOrEqual(t, len(messages), 3,
"second upstream request must contain the original message, assistant tool_use, and user tool_result")

assistantMsg := messages[len(messages)-2]
require.Equal(t, "assistant", assistantMsg.Get("role").Str,
"penultimate message must be from the assistant")
var hasToolUse bool
for _, block := range assistantMsg.Get("content").Array() {
if block.Get("type").Str == "tool_use" {
hasToolUse = true
break
}
}
require.True(t, hasToolUse, "assistant message must contain a tool_use content block")

toolResultMsg := messages[len(messages)-1]
require.Equal(t, "user", toolResultMsg.Get("role").Str,
"last message must be a user message carrying the tool_result")
var hasToolResult bool
for _, block := range toolResultMsg.Get("content").Array() {
if block.Get("type").Str == "tool_result" {
hasToolResult = true
break
}
}
require.True(t, hasToolResult, "user message must contain a tool_result content block")
}
}

// openaiChatToolResultValidator returns a request validator that asserts the second
// upstream request contains the assistant's tool_calls and a role=tool result message
// appended by the inner agentic loop.
func openaiChatToolResultValidator(t *testing.T) func(*http.Request) {
t.Helper()

var reqNum atomic.Uint32
return func(r *http.Request) {
raw, err := io.ReadAll(r.Body)
require.NoError(t, err)
r.Body = io.NopCloser(bytes.NewReader(raw))

if reqNum.Add(1) != 2 {
return
}

messages := gjson.GetBytes(raw, "messages").Array()

// After the agentic loop the messages must contain at minimum:
// [0] original user message
// [N-2] assistant message with tool_calls array
// [N-1] message with role=tool
require.GreaterOrEqual(t, len(messages), 3,
"second upstream request must contain the original message, assistant tool_calls, and tool result")

assistantMsg := messages[len(messages)-2]
require.Equal(t, "assistant", assistantMsg.Get("role").Str,
"penultimate message must be from the assistant")
require.NotEmpty(t, len(assistantMsg.Get("tool_calls").Array()),
"assistant message must contain a tool_calls array")

toolResultMsg := messages[len(messages)-1]
require.Equal(t, "tool", toolResultMsg.Get("role").Str,
"last message must have role=tool")
require.NotEmpty(t, toolResultMsg.Get("tool_call_id").Str,
"tool result message must have a tool_call_id")
}
}

// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests.
// Kinda fugly right now, we can refactor this later.
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) {
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request, requestValidatorFn func(*http.Request)) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) {
t.Helper()

arc := txtar.Parse(fixture)
Expand All @@ -1174,7 +1264,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
t.Cleanup(cancel)

// Setup mock server with response mutator for multi-turn interaction.
mockSrv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte {
mockSrv := newMockServer(ctx, t, files, requestValidatorFn, func(reqCount uint32, resp []byte) []byte {
if reqCount == 1 {
return resp // First request gets the normal response (with tool call).
}
Expand Down
21 changes: 19 additions & 2 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,28 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio
i.augmentRequestForBedrock()
}

// Must be after any request augmentation, eg. i.augmentRequestForBedrock() and i.injectTools()
opts = append(opts, option.WithRequestBody("application/json", i.payload))
return anthropic.NewMessageService(opts...), nil
}

// withBody returns a per-request option that sends the current i.payload as the
// request body. This is called for each API request so that the latest payload (including
// any messages appended during the agentic tool loop) is always sent.
Comment thread
dannykopping marked this conversation as resolved.
func (i *interceptionBase) withBody() option.RequestOption {
	// i.payload is read when withBody is called, so the returned option carries
	// whatever the payload holds at that moment.
	const contentType = "application/json"
	body := i.payload
	return option.WithRequestBody(contentType, body)
}

// syncPayloadMessages rewrites the "messages" field of the raw payload (i.payload) so
// that it matches the given messages. It must be called before the next API request in
// the agentic loop so that a request body built from i.payload includes the updated
// messages.
//
// The updated payload is staged in a local and only stored on success, so a failed
// update leaves i.payload exactly as it was instead of overwriting it with whatever
// sjson.SetBytes returns alongside the error.
//
// It returns a wrapped error if the payload could not be updated.
func (i *interceptionBase) syncPayloadMessages(messages []anthropic.MessageParam) error {
	updated, err := sjson.SetBytes(i.payload, "messages", messages)
	if err != nil {
		return fmt.Errorf("sync payload messages: %w", err)
	}

	i.payload = updated
	return nil
}

func (i *interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
if cfg == nil {
return nil, fmt.Errorf("nil config given")
Expand Down
9 changes: 8 additions & 1 deletion intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
messages.Messages = append(messages.Messages, anthropic.NewUserMessage(toolResult))
}
}

// Sync the raw payload with updated messages so that withBody()
// sends the updated payload on the next iteration.
if err := i.syncPayloadMessages(messages.Messages); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return fmt.Errorf("sync payload for agentic loop: %w", err)
}
}

if resp == nil {
Expand Down Expand Up @@ -308,5 +315,5 @@ func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.Mes
ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer tracing.EndSpanErr(span, &outErr)

return svc.New(ctx, msgParams)
return svc.New(ctx, msgParams, i.withBody())
}
9 changes: 8 additions & 1 deletion intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,13 @@ newStream:
}
}

// Sync the raw payload with updated messages so that withBody()
// sends the updated payload on the next iteration.
if syncErr := i.syncPayloadMessages(messages.Messages); syncErr != nil {
lastErr = fmt.Errorf("sync payload for agentic loop: %w", syncErr)
break
}

// Causes a new stream to be run with updated messages.
isFirst = false
continue newStream
Expand Down Expand Up @@ -551,5 +558,5 @@ func (s *StreamingInterception) newStream(ctx context.Context, svc anthropic.Mes
_, span := s.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer span.End()

return svc.NewStreaming(ctx, messages)
return svc.NewStreaming(ctx, messages, s.withBody())
}
4 changes: 2 additions & 2 deletions trace_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, tc.streaming, configureFn, reqFunc)
recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, tc.streaming, configureFn, reqFunc, anthropicToolResultValidator(t))

defer resp.Body.Close()

Expand Down Expand Up @@ -734,7 +734,7 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, reqFunc)
recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, reqFunc, openaiChatToolResultValidator(t))

defer resp.Body.Close()

Expand Down