Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pkg/fake/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ func DefaultMatcher(onError func(err error)) recorder.MatcherFunc {
thinkingConfigRegex := regexp.MustCompile(`"thinkingConfig":\{[^}]*\},?`)
// Normalize OpenAI reasoning config (varies based on NoThinking flag and thinking budget).
reasoningRegex := regexp.MustCompile(`"reasoning":\{[^}]*\},?`)
// Normalize OpenAI tool_choice field. The string form ("auto", "none",
// "required") is now sent explicitly whenever tools are present so that
// strict gateways (LiteLLM) accept the request, but older cassettes were
// recorded without it.
toolChoiceRegex := regexp.MustCompile(`"tool_choice":"[^"]*",?`)

return func(r *http.Request, i cassette.Request) bool {
if r.Body == nil || r.Body == http.NoBody {
Expand Down Expand Up @@ -249,10 +254,12 @@ func DefaultMatcher(onError func(err error)) recorder.MatcherFunc {
normalizedReq = maxTokensRegex.ReplaceAllString(normalizedReq, "")
normalizedReq = thinkingConfigRegex.ReplaceAllString(normalizedReq, "")
normalizedReq = reasoningRegex.ReplaceAllString(normalizedReq, "")
normalizedReq = toolChoiceRegex.ReplaceAllString(normalizedReq, "")
normalizedCassette := callIDRegex.ReplaceAllString(i.Body, "call_ID")
normalizedCassette = maxTokensRegex.ReplaceAllString(normalizedCassette, "")
normalizedCassette = thinkingConfigRegex.ReplaceAllString(normalizedCassette, "")
normalizedCassette = reasoningRegex.ReplaceAllString(normalizedCassette, "")
normalizedCassette = toolChoiceRegex.ReplaceAllString(normalizedCassette, "")

return normalizedReq == normalizedCassette
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/model/provider/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ func (c *Client) CreateChatCompletionStream(
}
params.Tools = toolsParam

// Explicitly send tool_choice="auto". The OpenAI spec treats omission as
// equivalent to "auto", but some strict OpenAI-compatible gateways
// (notably LiteLLM) reject requests where tool_choice is missing while
// tools are present. Sending the default value explicitly is
// spec-compliant and preserves the model's autonomy to call tools.
params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: param.NewOpt("auto"),
}

if c.ModelConfig.ParallelToolCalls != nil {
params.ParallelToolCalls = openai.Bool(*c.ModelConfig.ParallelToolCalls)
}
Expand Down Expand Up @@ -407,6 +416,13 @@ func (c *Client) CreateResponseStream(
}
params.Tools = toolsParam

// Explicitly send tool_choice="auto". See the matching comment in the
// Chat Completions path above for rationale (LiteLLM-style gateways
// reject requests where tool_choice is omitted).
params.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{
OfToolChoiceMode: param.NewOpt(responses.ToolChoiceOptionsAuto),
}

if c.ModelConfig.ParallelToolCalls != nil {
params.ParallelToolCalls = param.NewOpt(*c.ModelConfig.ParallelToolCalls)
}
Expand Down
168 changes: 168 additions & 0 deletions pkg/model/provider/openai/tool_choice_responses_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package openai

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/tools"
)

// TestResponsesAPI_ToolChoiceAutoExplicit verifies that when tools are
// provided to the Responses API, the request body explicitly contains
// tool_choice=auto. This mirrors the Chat Completions test and ensures
// both API paths have consistent behavior for strict gateways like LiteLLM.
//
// See https://github.com/docker/docker-agent/issues/2804.
func TestResponsesAPI_ToolChoiceAutoExplicit(t *testing.T) {
t.Parallel()

var (
receivedBody []byte
mu sync.Mutex
)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
mu.Lock()
receivedBody = body
mu.Unlock()
writeResponsesSSEResponse(w)
}))
defer server.Close()

cfg := &latest.ModelConfig{
Provider: "custom",
Model: "test",
BaseURL: server.URL,
TokenKey: "MY_TOKEN",
ProviderOpts: map[string]any{
"api_type": "openai_responses",
},
}

env := environment.NewMapEnvProvider(map[string]string{
"MY_TOKEN": "secret",
})

client, err := NewClient(t.Context(), cfg, env)
require.NoError(t, err)

requestTools := []tools.Tool{
{
Name: "search",
Description: "Search the web",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"q": map[string]any{"type": "string"},
},
},
},
}

stream, err := client.CreateResponseStream(
t.Context(),
[]chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}},
requestTools,
)
require.NoError(t, err)
defer stream.Close()

for {
if _, err := stream.Recv(); err != nil {
break
}
}

mu.Lock()
defer mu.Unlock()

var payload map[string]json.RawMessage
require.NoError(t, json.Unmarshal(receivedBody, &payload))

raw, ok := payload["tool_choice"]
require.True(t, ok, "tool_choice must be present in the Responses API request body when tools are provided")

var s string
require.NoError(t, json.Unmarshal(raw, &s))
assert.Equal(t, "auto", s)
}

// TestResponsesAPI_NoToolChoiceWithoutTools verifies that when no tools
// are provided to the Responses API, we don't send a tool_choice field.
func TestResponsesAPI_NoToolChoiceWithoutTools(t *testing.T) {
t.Parallel()

var (
receivedBody []byte
mu sync.Mutex
)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
mu.Lock()
receivedBody = body
mu.Unlock()
writeResponsesSSEResponse(w)
}))
defer server.Close()

cfg := &latest.ModelConfig{
Provider: "custom",
Model: "test",
BaseURL: server.URL,
TokenKey: "MY_TOKEN",
ProviderOpts: map[string]any{
"api_type": "openai_responses",
},
}

env := environment.NewMapEnvProvider(map[string]string{
"MY_TOKEN": "secret",
})

client, err := NewClient(t.Context(), cfg, env)
require.NoError(t, err)

stream, err := client.CreateResponseStream(
t.Context(),
[]chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}},
nil,
)
require.NoError(t, err)
defer stream.Close()

for {
if _, err := stream.Recv(); err != nil {
break
}
}

mu.Lock()
defer mu.Unlock()

var payload map[string]json.RawMessage
require.NoError(t, json.Unmarshal(receivedBody, &payload))

_, ok := payload["tool_choice"]
assert.False(t, ok, "tool_choice must not be present in Responses API when no tools are provided")
}

// writeResponsesSSEResponse replies with a canned server-sent-events body for
// the Responses API tests: a single "response.output_item.done" event followed
// by a "[DONE]" marker, served with status 200 and an event-stream content type.
func writeResponsesSSEResponse(w http.ResponseWriter) {
	const itemDoneEvent = `data: {"type":"response.output_item.done","output_item":{"type":"message","role":"assistant","content":[{"type":"text","text":"ok"}]}}`

	w.Header().Set("Content-Type", "text/event-stream")
	w.WriteHeader(http.StatusOK)

	// Each chunk is written separately and in order; the blank line ("\n\n")
	// terminates the preceding SSE event. Write errors are intentionally
	// ignored: this is a best-effort test fixture.
	for _, chunk := range []string{itemDoneEvent, "\n\n", "data: [DONE]\n\n"} {
		_, _ = w.Write([]byte(chunk))
	}
}
160 changes: 160 additions & 0 deletions pkg/model/provider/openai/tool_choice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package openai

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/tools"
)

// TestChatCompletions_ToolChoiceAutoExplicit verifies that when tools are
// provided, the request body explicitly contains tool_choice=auto. This is
// required by some strict OpenAI-compatible gateways (e.g. LiteLLM) that
// reject requests with a missing tool_choice field even though the OpenAI
// spec treats omission as equivalent to "auto".
//
// See https://github.com/docker/docker-agent/issues/2804.
func TestChatCompletions_ToolChoiceAutoExplicit(t *testing.T) {
t.Parallel()

var (
receivedBody []byte
mu sync.Mutex
)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
mu.Lock()
receivedBody = body
mu.Unlock()
writeSSEResponse(w)
}))
defer server.Close()

cfg := &latest.ModelConfig{
Provider: "custom",
Model: "test",
BaseURL: server.URL,
TokenKey: "MY_TOKEN",
ProviderOpts: map[string]any{
"api_type": "openai_chatcompletions",
},
}

env := environment.NewMapEnvProvider(map[string]string{
"MY_TOKEN": "secret",
})

client, err := NewClient(t.Context(), cfg, env)
require.NoError(t, err)

requestTools := []tools.Tool{
{
Name: "search",
Description: "Search the web",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"q": map[string]any{"type": "string"},
},
},
},
}

stream, err := client.CreateChatCompletionStream(
t.Context(),
[]chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}},
requestTools,
)
require.NoError(t, err)
defer stream.Close()

for {
if _, err := stream.Recv(); err != nil {
break
}
}

mu.Lock()
defer mu.Unlock()

var payload map[string]json.RawMessage
require.NoError(t, json.Unmarshal(receivedBody, &payload))

raw, ok := payload["tool_choice"]
require.True(t, ok, "tool_choice must be present in the request body when tools are provided")

var s string
require.NoError(t, json.Unmarshal(raw, &s))
assert.Equal(t, "auto", s)
}

// TestChatCompletions_NoToolChoiceWithoutTools verifies that when no tools
// are provided we don't send a tool_choice field (which would be invalid).
func TestChatCompletions_NoToolChoiceWithoutTools(t *testing.T) {
t.Parallel()

var (
receivedBody []byte
mu sync.Mutex
)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
mu.Lock()
receivedBody = body
mu.Unlock()
writeSSEResponse(w)
}))
defer server.Close()

cfg := &latest.ModelConfig{
Provider: "custom",
Model: "test",
BaseURL: server.URL,
TokenKey: "MY_TOKEN",
ProviderOpts: map[string]any{
"api_type": "openai_chatcompletions",
},
}

env := environment.NewMapEnvProvider(map[string]string{
"MY_TOKEN": "secret",
})

client, err := NewClient(t.Context(), cfg, env)
require.NoError(t, err)

stream, err := client.CreateChatCompletionStream(
t.Context(),
[]chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}},
nil,
)
require.NoError(t, err)
defer stream.Close()

for {
if _, err := stream.Recv(); err != nil {
break
}
}

mu.Lock()
defer mu.Unlock()

var payload map[string]json.RawMessage
require.NoError(t, json.Unmarshal(receivedBody, &payload))

_, ok := payload["tool_choice"]
assert.False(t, ok, "tool_choice must not be present when no tools are provided")
}
Loading