From 036d2d79e94b324a56eded1fd3182d4dc7f3803f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 19 Feb 2026 09:52:52 +0000 Subject: [PATCH 1/2] fix: fix adaptive thinking being removed from request --- bridge_integration_test.go | 60 +++++++++++++++++++++++++++++++++ intercept/messages/base.go | 24 +++++++++++-- intercept/messages/blocking.go | 3 +- intercept/messages/streaming.go | 3 +- provider/anthropic.go | 13 ++++--- 5 files changed, 95 insertions(+), 8 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b54a3bb4..cceb6ae9 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1732,6 +1732,66 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { } } +func TestThinkingAdaptiveIsPreserved(t *testing.T) { + t.Parallel() + + arc := txtar.Parse(fixtures.AntSimple) + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) + require.Contains(t, files, fixtureNonStreamingResponse) + + for _, streaming := range []bool{true, false} { + t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { + t.Parallel() + + // Inject adaptive thinking into the fixture request. + reqBody, err := sjson.SetBytes(files[fixtureRequest], "thinking", map[string]string{"type": "adaptive"}) + require.NoError(t, err) + reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + var receivedRequest []byte + + // Create a mock server that captures the request body sent upstream. + srv := newMockServer(ctx, t, files, func(r *http.Request) { + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + r.Body = io.NopCloser(bytes.NewReader(raw)) + receivedRequest = raw + }, nil) + t.Cleanup(srv.Close) + + recorderClient := &testutil.MockRecorder{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil)} + bridge, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + require.NoError(t, err) + + bridgeSrv := httptest.NewUnstartedServer(bridge) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + t.Cleanup(bridgeSrv.Close) + + req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + + // Verify the thinking field was preserved in the upstream request. + require.NotEmpty(t, receivedRequest) + assert.Equal(t, "adaptive", gjson.GetBytes(receivedRequest, "thinking.type").Str) + }) + } +} + func TestEnvironmentDoNotLeak(t *testing.T) { // NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution. diff --git a/intercept/messages/base.go b/intercept/messages/base.go index 6db64fa9..753bdad0 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -23,6 +23,7 @@ import ( "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" "github.com/coder/quartz" + "github.com/tidwall/sjson" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" @@ -32,8 +33,9 @@ import ( ) type interceptionBase struct { - id uuid.UUID - req *MessageNewParamsWrapper + id uuid.UUID + req *MessageNewParamsWrapper + payload []byte cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -115,6 +117,12 @@ func (i *interceptionBase) injectTools() { // any cache invalidation when prepended. i.req.Tools = append(injectedTools, i.req.Tools...) + var err error + i.payload, err = sjson.SetBytes(i.payload, "tools", i.req.Tools) + if err != nil { + i.logger.Warn(context.Background(), "failed to set inject tools in request payload", slog.Error(err)) + } + // Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches. // https://github.com/coder/aibridge/issues/2 toolChoiceType := i.req.ToolChoice.GetType() @@ -140,6 +148,10 @@ func (i *interceptionBase) injectTools() { case string(constant.ValueOf[constant.None]()): // No-op; if tool_choice=none then tools are not used at all. } + i.payload, err = sjson.SetBytes(i.payload, "tool_choice", i.req.ToolChoice) + if err != nil { + i.logger.Warn(context.Background(), "failed to set tool_choice in request payload", slog.Error(err)) + } } // IsSmallFastModel checks if the model is a small/fast model (Haiku 3.5). @@ -170,6 +182,8 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio i.augmentRequestForBedrock() } + // Must be after any request augmentation, eg. i.augmentRequestForBedrock() and i.injectTools() + opts = append(opts, option.WithRequestBody("application/json", i.payload)) return anthropic.NewMessageService(opts...), nil } @@ -228,6 +242,12 @@ func (i *interceptionBase) augmentRequestForBedrock() { } i.req.MessageNewParams.Model = anthropic.Model(i.Model()) + + var err error + i.payload, err = sjson.SetBytes(i.payload, "model", i.Model()) + if err != nil { + i.logger.Warn(context.Background(), "failed to set model in request payload for Bedrock", slog.Error(err)) + } } // writeUpstreamError marshals and writes a given error. diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index f027e497..594af484 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -28,10 +28,11 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, req: req, + payload: payload, cfg: cfg, bedrockCfg: bedrockCfg, tracer: tracer, diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 31d0ea9a..627933ca 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -34,10 +34,11 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor(id uuid.UUID, req *MessageNewParamsWrapper, payload []byte, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, tracer trace.Tracer) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, req: req, + payload: payload, cfg: cfg, bedrockCfg: bedrockCfg, tracer: tracer, diff --git a/provider/anthropic.go b/provider/anthropic.go index c845bde9..800f3a9b 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -3,6 +3,7 @@ package provider import ( "encoding/json" "fmt" + "io" "net/http" "os" "strings" @@ -89,16 +90,20 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix()) switch path { case routeMessages: + payload, err := io.ReadAll(r.Body) + if err != nil { + return nil, fmt.Errorf("read body: %w", err) + } var req messages.MessageNewParamsWrapper - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, fmt.Errorf("failed to unmarshal request: %w", err) + if err := json.Unmarshal(payload, &req); err != nil { + return nil, fmt.Errorf("unmarshal request body: %w", err) } var interceptor intercept.Interceptor if req.Stream { - interceptor = messages.NewStreamingInterceptor(id, &req, p.cfg, p.bedrockCfg, tracer) + interceptor = messages.NewStreamingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, &req, p.cfg, p.bedrockCfg, tracer) + interceptor = messages.NewBlockingInterceptor(id, &req, payload, p.cfg, p.bedrockCfg, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil From 58d2ebfa498f89dfe7049aa940dc6c4cc2a9b5c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Thu, 19 Feb 2026 16:05:56 +0000 Subject: [PATCH 2/2] review: json.Unmarshal -> json.NewDecoder --- provider/anthropic.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/provider/anthropic.go b/provider/anthropic.go index 800f3a9b..be12583b 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -1,6 +1,7 @@ package provider import ( + "bytes" "encoding/json" "fmt" "io" @@ -95,7 +96,7 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr return nil, fmt.Errorf("read body: %w", err) } var req messages.MessageNewParamsWrapper - if err := json.Unmarshal(payload, &req); err != nil { + if err := json.NewDecoder(bytes.NewReader(payload)).Decode(&req); err != nil { return nil, fmt.Errorf("unmarshal request body: %w", err) }