Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pkg/acp/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/model/provider/base"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/team"
"github.com/docker/docker-agent/pkg/tools"
Expand All @@ -38,11 +39,11 @@ func (m *mockStream) Close() {}

// mockProvider returns a predetermined stream for testing.
type mockProvider struct {
id string
id modelsdev.ID
stream chat.MessageStream
}

func (m *mockProvider) ID() string { return m.id }
func (m *mockProvider) ID() modelsdev.ID { return m.id }

func (m *mockProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) {
return m.stream, nil
Expand Down Expand Up @@ -85,7 +86,7 @@ func TestACPSessionPersistence(t *testing.T) {
},
},
}
prov := &mockProvider{id: "test/mock-model", stream: stream}
prov := &mockProvider{id: modelsdev.NewID("test", "mock-model"), stream: stream}

// Create a minimal team with a root agent
root := agent.New("root", "You are a test agent", agent.WithModel(prov))
Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (a *Agent) SetModelOverride(models ...provider.Provider) ModelOverrideSnaps
a.modelOverrides.Store(ptr)
ids := make([]string, len(validModels))
for i, m := range validModels {
ids[i] = m.ID()
ids[i] = m.ID().String()
}
slog.Debug("Set model override", "agent", a.name, "models", ids)
}
Expand Down
61 changes: 31 additions & 30 deletions pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/concurrent"
"github.com/docker/docker-agent/pkg/model/provider/base"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/tools"
)

Expand Down Expand Up @@ -127,10 +128,10 @@ func TestAgentTools(t *testing.T) {

// mockProvider implements provider.Provider for testing
type mockProvider struct {
id string
id modelsdev.ID
}

func (m *mockProvider) ID() string { return m.id }
func (m *mockProvider) ID() modelsdev.ID { return m.id }
func (m *mockProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
return nil, nil
}
Expand All @@ -139,29 +140,29 @@ func (m *mockProvider) BaseConfig() base.Config { return base.Config{} }
func TestModelOverride(t *testing.T) {
t.Parallel()

defaultModel := &mockProvider{id: "openai/gpt-4o"}
overrideModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"}
defaultModel := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o")}
overrideModel := &mockProvider{id: modelsdev.NewID("anthropic", "claude-sonnet-4-0")}

a := New("root", "test", WithModel(defaultModel))

// Initially should return the default model
assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID())
assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID().String())
assert.False(t, a.HasModelOverride())

// Set an override
a.SetModelOverride(overrideModel)
assert.True(t, a.HasModelOverride())
assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID())
assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID().String())

// ConfiguredModels still reflects the originally configured models
configuredModels := a.ConfiguredModels()
require.Len(t, configuredModels, 1)
assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID())
assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID().String())

// Clear the override
a.SetModelOverride(nil)
assert.False(t, a.HasModelOverride())
assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID())
assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID().String())
}

func TestSetModelOverride_ReturnsSnapshotOfStoredValue(t *testing.T) {
Expand All @@ -174,9 +175,9 @@ func TestSetModelOverride_ReturnsSnapshotOfStoredValue(t *testing.T) {
// change wins) instead of incorrectly succeeding.
t.Parallel()

defaultModel := &mockProvider{id: "default"}
oursModel := &mockProvider{id: "ours"}
othersModel := &mockProvider{id: "others"}
defaultModel := &mockProvider{id: modelsdev.NewID("default", "x")}
oursModel := &mockProvider{id: modelsdev.NewID("ours", "x")}
othersModel := &mockProvider{id: modelsdev.NewID("others", "x")}

a := New("root", "test", WithModel(defaultModel))

Expand All @@ -187,27 +188,27 @@ func TestSetModelOverride_ReturnsSnapshotOfStoredValue(t *testing.T) {
// Simulate a concurrent caller storing a different override _after_ we
// stored ours but _before_ a hypothetical post-store SnapshotModelOverride.
a.SetModelOverride(othersModel)
require.Equal(t, "others", a.Model(t.Context()).ID())
require.Equal(t, "others/x", a.Model(t.Context()).ID().String())

// The deferred restore must be a no-op because oursSnap holds the
// pointer we stored, not the current pointer.
a.RestoreModelOverride(prev, oursSnap)
assert.Equal(t, "others", a.Model(t.Context()).ID(),
assert.Equal(t, "others/x", a.Model(t.Context()).ID().String(),
"concurrent override must be preserved; the snapshot returned by SetModelOverride captures the stored pointer")
}

func TestSetModelOverride_ClearReturnsZeroSnapshot(t *testing.T) {
t.Parallel()

a := New("root", "test", WithModel(&mockProvider{id: "default"}))
a := New("root", "test", WithModel(&mockProvider{id: modelsdev.NewID("default", "x")}))

// Calling SetModelOverride with no providers (or nil) clears the override.
// The returned snapshot should round-trip cleanly through RestoreModelOverride.
cleared := a.SetModelOverride()
assert.False(t, a.HasModelOverride())

// Now set an override and restore using `cleared` as `prev`.
oursSnap := a.SetModelOverride(&mockProvider{id: "ours"})
oursSnap := a.SetModelOverride(&mockProvider{id: modelsdev.NewID("ours", "x")})
require.True(t, a.HasModelOverride())

a.RestoreModelOverride(cleared, oursSnap)
Expand All @@ -217,9 +218,9 @@ func TestSetModelOverride_ClearReturnsZeroSnapshot(t *testing.T) {
func TestSnapshotAndRestoreModelOverride(t *testing.T) {
t.Parallel()

defaultModel := &mockProvider{id: "openai/gpt-4o"}
skillModel := &mockProvider{id: "openai/gpt-4o-mini"}
userModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"}
defaultModel := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o")}
skillModel := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o-mini")}
userModel := &mockProvider{id: modelsdev.NewID("anthropic", "claude-sonnet-4-0")}

t.Run("restores when no concurrent change", func(t *testing.T) {
t.Parallel()
Expand All @@ -228,11 +229,11 @@ func TestSnapshotAndRestoreModelOverride(t *testing.T) {
prev := a.SnapshotModelOverride()
a.SetModelOverride(skillModel)
ours := a.SnapshotModelOverride()
assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID())
assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID().String())

a.RestoreModelOverride(prev, ours)
assert.False(t, a.HasModelOverride())
assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID())
assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID().String())
})

t.Run("restores back to a pre-existing override", func(t *testing.T) {
Expand All @@ -243,10 +244,10 @@ func TestSnapshotAndRestoreModelOverride(t *testing.T) {
prev := a.SnapshotModelOverride()
a.SetModelOverride(skillModel)
ours := a.SnapshotModelOverride()
assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID())
assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID().String())

a.RestoreModelOverride(prev, ours)
assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID())
assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID().String())
})

t.Run("keeps a concurrent change instead of restoring", func(t *testing.T) {
Expand All @@ -266,7 +267,7 @@ func TestSnapshotAndRestoreModelOverride(t *testing.T) {

a.RestoreModelOverride(prev, ours)
require.True(t, a.HasModelOverride(), "user's model choice must be preserved")
assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID())
assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID().String())
})

t.Run("keeps a concurrent clear instead of restoring", func(t *testing.T) {
Expand Down Expand Up @@ -297,8 +298,8 @@ func TestModel_LogsSelection(t *testing.T) {
slog.SetDefault(slog.New(handler))
t.Cleanup(func() { slog.SetDefault(prev) })

model1 := &mockProvider{id: "anthropic/claude-sonnet-4-0"}
model2 := &mockProvider{id: "openai/gpt-4o"}
model1 := &mockProvider{id: modelsdev.NewID("anthropic", "claude-sonnet-4-0")}
model2 := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o")}

a := New("scanner", "test", WithModel(model1), WithModel(model2))

Expand All @@ -308,27 +309,27 @@ func TestModel_LogsSelection(t *testing.T) {

assert.Contains(t, logOutput, "Model selected")
assert.Contains(t, logOutput, "agent=scanner")
assert.Contains(t, logOutput, selected.ID())
assert.Contains(t, logOutput, selected.ID().String())
assert.Contains(t, logOutput, "pool_size=2")

// Verify override scenario logs correct pool_size
buf.Reset()
override := &mockProvider{id: "google/gemini-2.0-flash"}
override := &mockProvider{id: modelsdev.NewID("google", "gemini-2.0-flash")}
a.SetModelOverride(override)

selected = a.Model(t.Context())
logOutput = buf.String()

assert.Equal(t, "google/gemini-2.0-flash", selected.ID())
assert.Equal(t, "google/gemini-2.0-flash", selected.ID().String())
assert.Contains(t, logOutput, "google/gemini-2.0-flash")
assert.Contains(t, logOutput, "pool_size=1")
}

func TestModelOverride_ConcurrentAccess(t *testing.T) {
t.Parallel()

defaultModel := &mockProvider{id: "default"}
overrideModel := &mockProvider{id: "override"}
defaultModel := &mockProvider{id: modelsdev.NewID("default", "x")}
overrideModel := &mockProvider{id: modelsdev.NewID("override", "x")}

a := New("root", "test", WithModel(defaultModel))

Expand Down
2 changes: 1 addition & 1 deletion pkg/config/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestParseExamples(t *testing.T) {
continue
}

model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model)
model, err := modelsStore.GetModel(t.Context(), modelsdev.NewID(model.Provider, model.Model))
require.NoError(t, err)
require.NotNil(t, model)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/model/provider/anthropic/attachments.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
// - application/pdf with InlineData → DocumentBlockParam (base64)
// - text with InlineText → TextBlockParam with TXTEnvelope
// - unsupported / no content → nil (logged as warning)
func convertDocument(ctx context.Context, doc chat.Document, modelID string, store *modelsdev.Store) ([]anthropic.ContentBlockParamUnion, error) {
mc := modelinfo.LoadCaps(store, modelID)
func convertDocument(ctx context.Context, doc chat.Document, id modelsdev.ID, store *modelsdev.Store) ([]anthropic.ContentBlockParamUnion, error) {
mc := modelinfo.LoadCaps(store, id)
return convertDocumentWithCaps(ctx, doc, mc)
}

Expand Down
11 changes: 7 additions & 4 deletions pkg/model/provider/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/modelsdev"
)

const NoDesktopTokenErrorMessage = "failed to get Docker Desktop token for Gateway. Is Docker Desktop running and are you signed in?"
Expand All @@ -30,11 +31,13 @@ type Config struct {
BaseURL string
}

// ID returns the provider and model ID in the format "provider/model".
// Uses DisplayModel (the original user-configured name) when available,
// ID returns the provider and model identity as a [modelsdev.ID] so
// callers cannot accidentally pass a bare model string where a
// provider-qualified identity is required. The model component uses
// DisplayModel (the original user-configured name) when available,
// falling back to Model (the resolved/pinned name).
func (c *Config) ID() string {
return c.ModelConfig.Provider + "/" + c.ModelConfig.DisplayOrModel()
func (c *Config) ID() modelsdev.ID {
return modelsdev.NewID(c.ModelConfig.Provider, c.ModelConfig.DisplayOrModel())
}

func (c *Config) BaseConfig() Config {
Expand Down
4 changes: 2 additions & 2 deletions pkg/model/provider/bedrock/attachments.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func imageFormatFromMIME(mimeType string) (types.ImageFormat, bool) {
// - application/pdf with InlineData → ContentBlockMemberDocument (PDF)
// - text/* with InlineText → ContentBlockMemberText with TXTEnvelope
// - unsupported / no content → nil (logged as warning)
func convertDocument(ctx context.Context, doc chat.Document, modelID string, store *modelsdev.Store) ([]types.ContentBlock, error) {
mc := modelinfo.LoadCaps(store, modelID)
func convertDocument(ctx context.Context, doc chat.Document, id modelsdev.ID, store *modelsdev.Store) ([]types.ContentBlock, error) {
mc := modelinfo.LoadCaps(store, id)
return convertDocumentWithCaps(ctx, doc, mc)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/model/provider/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ func detectCachingSupport(ctx context.Context, model string, store *modelsdev.St
return false
}

modelID := "amazon-bedrock/" + model
m, err := store.GetModel(ctx, modelID)
id := modelsdev.NewID("amazon-bedrock", model)
m, err := store.GetModel(ctx, id)
if err != nil {
slog.DebugContext(ctx, "Bedrock prompt caching disabled: model not found in models.dev",
"model_id", modelID, "error", err)
"model_id", id.String(), "error", err)
return false
}

Expand Down
Loading
Loading