diff --git a/pkg/acp/agent_test.go b/pkg/acp/agent_test.go index 1fbfd8815..ab787ae22 100644 --- a/pkg/acp/agent_test.go +++ b/pkg/acp/agent_test.go @@ -14,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" @@ -38,11 +39,11 @@ func (m *mockStream) Close() {} // mockProvider returns a predetermined stream for testing. type mockProvider struct { - id string + id modelsdev.ID stream chat.MessageStream } -func (m *mockProvider) ID() string { return m.id } +func (m *mockProvider) ID() modelsdev.ID { return m.id } func (m *mockProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { return m.stream, nil @@ -85,7 +86,7 @@ func TestACPSessionPersistence(t *testing.T) { }, }, } - prov := &mockProvider{id: "test/mock-model", stream: stream} + prov := &mockProvider{id: modelsdev.NewID("test", "mock-model"), stream: stream} // Create a minimal team with a root agent root := agent.New("root", "You are a test agent", agent.WithModel(prov)) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index f4c236a12..d50a6c577 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -195,7 +195,7 @@ func (a *Agent) SetModelOverride(models ...provider.Provider) ModelOverrideSnaps a.modelOverrides.Store(ptr) ids := make([]string, len(validModels)) for i, m := range validModels { - ids[i] = m.ID() + ids[i] = m.ID().String() } slog.Debug("Set model override", "agent", a.name, "models", ids) } diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 2edb48f1a..556c82bd0 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -14,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/tools" ) @@ -127,10 +128,10 @@ func TestAgentTools(t *testing.T) { // mockProvider implements provider.Provider for testing type mockProvider struct { - id string + id modelsdev.ID } -func (m *mockProvider) ID() string { return m.id } +func (m *mockProvider) ID() modelsdev.ID { return m.id } func (m *mockProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { return nil, nil } @@ -139,29 +140,29 @@ func (m *mockProvider) BaseConfig() base.Config { return base.Config{} } func TestModelOverride(t *testing.T) { t.Parallel() - defaultModel := &mockProvider{id: "openai/gpt-4o"} - overrideModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"} + defaultModel := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o")} + overrideModel := &mockProvider{id: modelsdev.NewID("anthropic", "claude-sonnet-4-0")} a := New("root", "test", WithModel(defaultModel)) // Initially should return the default model - assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID()) + assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID().String()) assert.False(t, a.HasModelOverride()) // Set an override a.SetModelOverride(overrideModel) assert.True(t, a.HasModelOverride()) - assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID()) + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID().String()) // ConfiguredModels still reflects the originally configured models configuredModels := a.ConfiguredModels() require.Len(t, configuredModels, 1) - assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID()) + assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID().String()) // Clear the override a.SetModelOverride(nil) assert.False(t, a.HasModelOverride()) - assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID()) + assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID().String()) } func TestSetModelOverride_ReturnsSnapshotOfStoredValue(t *testing.T) { @@ -174,9 +175,9 @@ func TestSetModelOverride_ReturnsSnapshotOfStoredValue(t *testing.T) { // change wins) instead of incorrectly succeeding. t.Parallel() - defaultModel := &mockProvider{id: "default"} - oursModel := &mockProvider{id: "ours"} - othersModel := &mockProvider{id: "others"} + defaultModel := &mockProvider{id: modelsdev.NewID("default", "x")} + oursModel := &mockProvider{id: modelsdev.NewID("ours", "x")} + othersModel := &mockProvider{id: modelsdev.NewID("others", "x")} a := New("root", "test", WithModel(defaultModel)) @@ -187,19 +188,19 @@ func TestSetModelOverride_ReturnsSnapshotOfStoredValue(t *testing.T) { // Simulate a concurrent caller storing a different override _after_ we // stored ours but _before_ a hypothetical post-store SnapshotModelOverride. a.SetModelOverride(othersModel) - require.Equal(t, "others", a.Model(t.Context()).ID()) + require.Equal(t, "others/x", a.Model(t.Context()).ID().String()) // The deferred restore must be a no-op because oursSnap holds the // pointer we stored, not the current pointer. a.RestoreModelOverride(prev, oursSnap) - assert.Equal(t, "others", a.Model(t.Context()).ID(), + assert.Equal(t, "others/x", a.Model(t.Context()).ID().String(), "concurrent override must be preserved; the snapshot returned by SetModelOverride captures the stored pointer") } func TestSetModelOverride_ClearReturnsZeroSnapshot(t *testing.T) { t.Parallel() - a := New("root", "test", WithModel(&mockProvider{id: "default"})) + a := New("root", "test", WithModel(&mockProvider{id: modelsdev.NewID("default", "x")})) // Calling SetModelOverride with no providers (or nil) clears the override. // The returned snapshot should round-trip cleanly through RestoreModelOverride. @@ -207,7 +208,7 @@ func TestSetModelOverride_ClearReturnsZeroSnapshot(t *testing.T) { assert.False(t, a.HasModelOverride()) // Now set an override and restore using `cleared` as `prev`. - oursSnap := a.SetModelOverride(&mockProvider{id: "ours"}) + oursSnap := a.SetModelOverride(&mockProvider{id: modelsdev.NewID("ours", "x")}) require.True(t, a.HasModelOverride()) a.RestoreModelOverride(cleared, oursSnap) @@ -217,9 +218,9 @@ func TestSetModelOverride_ClearReturnsZeroSnapshot(t *testing.T) { func TestSnapshotAndRestoreModelOverride(t *testing.T) { t.Parallel() - defaultModel := &mockProvider{id: "openai/gpt-4o"} - skillModel := &mockProvider{id: "openai/gpt-4o-mini"} - userModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"} + defaultModel := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o")} + skillModel := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o-mini")} + userModel := &mockProvider{id: modelsdev.NewID("anthropic", "claude-sonnet-4-0")} t.Run("restores when no concurrent change", func(t *testing.T) { t.Parallel() @@ -228,11 +229,11 @@ func TestSnapshotAndRestoreModelOverride(t *testing.T) { prev := a.SnapshotModelOverride() a.SetModelOverride(skillModel) ours := a.SnapshotModelOverride() - assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID()) + assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID().String()) a.RestoreModelOverride(prev, ours) assert.False(t, a.HasModelOverride()) - assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID()) + assert.Equal(t, "openai/gpt-4o", a.Model(t.Context()).ID().String()) }) t.Run("restores back to a pre-existing override", func(t *testing.T) { @@ -243,10 +244,10 @@ func TestSnapshotAndRestoreModelOverride(t *testing.T) { prev := a.SnapshotModelOverride() a.SetModelOverride(skillModel) ours := a.SnapshotModelOverride() - assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID()) + assert.Equal(t, "openai/gpt-4o-mini", a.Model(t.Context()).ID().String()) a.RestoreModelOverride(prev, ours) - assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID()) + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID().String()) }) t.Run("keeps a concurrent change instead of restoring", func(t *testing.T) { @@ -266,7 +267,7 @@ func TestSnapshotAndRestoreModelOverride(t *testing.T) { a.RestoreModelOverride(prev, ours) require.True(t, a.HasModelOverride(), "user's model choice must be preserved") - assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID()) + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model(t.Context()).ID().String()) }) t.Run("keeps a concurrent clear instead of restoring", func(t *testing.T) { @@ -297,8 +298,8 @@ func TestModel_LogsSelection(t *testing.T) { slog.SetDefault(slog.New(handler)) t.Cleanup(func() { slog.SetDefault(prev) }) - model1 := &mockProvider{id: "anthropic/claude-sonnet-4-0"} - model2 := &mockProvider{id: "openai/gpt-4o"} + model1 := &mockProvider{id: modelsdev.NewID("anthropic", "claude-sonnet-4-0")} + model2 := &mockProvider{id: modelsdev.NewID("openai", "gpt-4o")} a := New("scanner", "test", WithModel(model1), WithModel(model2)) @@ -308,18 +309,18 @@ func TestModel_LogsSelection(t *testing.T) { assert.Contains(t, logOutput, "Model selected") assert.Contains(t, logOutput, "agent=scanner") - assert.Contains(t, logOutput, selected.ID()) + assert.Contains(t, logOutput, selected.ID().String()) assert.Contains(t, logOutput, "pool_size=2") // Verify override scenario logs correct pool_size buf.Reset() - override := &mockProvider{id: "google/gemini-2.0-flash"} + override := &mockProvider{id: modelsdev.NewID("google", "gemini-2.0-flash")} a.SetModelOverride(override) selected = a.Model(t.Context()) logOutput = buf.String() - assert.Equal(t, "google/gemini-2.0-flash", selected.ID()) + assert.Equal(t, "google/gemini-2.0-flash", selected.ID().String()) assert.Contains(t, logOutput, "google/gemini-2.0-flash") assert.Contains(t, logOutput, "pool_size=1") } @@ -327,8 +328,8 @@ func TestModel_LogsSelection(t *testing.T) { func TestModelOverride_ConcurrentAccess(t *testing.T) { t.Parallel() - defaultModel := &mockProvider{id: "default"} - overrideModel := &mockProvider{id: "override"} + defaultModel := &mockProvider{id: modelsdev.NewID("default", "x")} + overrideModel := &mockProvider{id: modelsdev.NewID("override", "x")} a := New("root", "test", WithModel(defaultModel)) diff --git a/pkg/config/examples_test.go b/pkg/config/examples_test.go index 5a9aee6ac..a31798a30 100644 --- a/pkg/config/examples_test.go +++ b/pkg/config/examples_test.go @@ -73,7 +73,7 @@ func TestParseExamples(t *testing.T) { continue } - model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model) + model, err := modelsStore.GetModel(t.Context(), modelsdev.NewID(model.Provider, model.Model)) require.NoError(t, err) require.NotNil(t, model) } diff --git a/pkg/model/provider/anthropic/attachments.go b/pkg/model/provider/anthropic/attachments.go index 04f5bf838..1583dbcfb 100644 --- a/pkg/model/provider/anthropic/attachments.go +++ b/pkg/model/provider/anthropic/attachments.go @@ -23,8 +23,8 @@ import ( // - application/pdf with InlineData → DocumentBlockParam (base64) // - text with InlineText → TextBlockParam with TXTEnvelope // - unsupported / no content → nil (logged as warning) -func convertDocument(ctx context.Context, doc chat.Document, modelID string, store *modelsdev.Store) ([]anthropic.ContentBlockParamUnion, error) { - mc := modelinfo.LoadCaps(store, modelID) +func convertDocument(ctx context.Context, doc chat.Document, id modelsdev.ID, store *modelsdev.Store) ([]anthropic.ContentBlockParamUnion, error) { + mc := modelinfo.LoadCaps(store, id) return convertDocumentWithCaps(ctx, doc, mc) } diff --git a/pkg/model/provider/base/base.go b/pkg/model/provider/base/base.go index b757590b1..7a4eb61e3 100644 --- a/pkg/model/provider/base/base.go +++ b/pkg/model/provider/base/base.go @@ -4,6 +4,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" ) const NoDesktopTokenErrorMessage = "failed to get Docker Desktop token for Gateway. Is Docker Desktop running and are you signed in?" @@ -30,11 +31,13 @@ type Config struct { BaseURL string } -// ID returns the provider and model ID in the format "provider/model". -// Uses DisplayModel (the original user-configured name) when available, +// ID returns the provider and model identity as a [modelsdev.ID] so +// callers cannot accidentally pass a bare model string where a +// provider-qualified identity is required. The model component uses +// DisplayModel (the original user-configured name) when available, // falling back to Model (the resolved/pinned name). -func (c *Config) ID() string { - return c.ModelConfig.Provider + "/" + c.ModelConfig.DisplayOrModel() +func (c *Config) ID() modelsdev.ID { + return modelsdev.NewID(c.ModelConfig.Provider, c.ModelConfig.DisplayOrModel()) } func (c *Config) BaseConfig() Config { diff --git a/pkg/model/provider/bedrock/attachments.go b/pkg/model/provider/bedrock/attachments.go index 03ea576f6..b6f1e6f93 100644 --- a/pkg/model/provider/bedrock/attachments.go +++ b/pkg/model/provider/bedrock/attachments.go @@ -41,8 +41,8 @@ func imageFormatFromMIME(mimeType string) (types.ImageFormat, bool) { // - application/pdf with InlineData → ContentBlockMemberDocument (PDF) // - text/* with InlineText → ContentBlockMemberText with TXTEnvelope // - unsupported / no content → nil (logged as warning) -func convertDocument(ctx context.Context, doc chat.Document, modelID string, store *modelsdev.Store) ([]types.ContentBlock, error) { - mc := modelinfo.LoadCaps(store, modelID) +func convertDocument(ctx context.Context, doc chat.Document, id modelsdev.ID, store *modelsdev.Store) ([]types.ContentBlock, error) { + mc := modelinfo.LoadCaps(store, id) return convertDocumentWithCaps(ctx, doc, mc) } diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index 0643768e5..282ef56f2 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -140,11 +140,11 @@ func detectCachingSupport(ctx context.Context, model string, store *modelsdev.St return false } - modelID := "amazon-bedrock/" + model - m, err := store.GetModel(ctx, modelID) + id := modelsdev.NewID("amazon-bedrock", model) + m, err := store.GetModel(ctx, id) if err != nil { slog.DebugContext(ctx, "Bedrock prompt caching disabled: model not found in models.dev", - "model_id", modelID, "error", err) + "model_id", id.String(), "error", err) return false } diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go index 9d0e09e5d..fda3fc397 100644 --- a/pkg/model/provider/bedrock/client_test.go +++ b/pkg/model/provider/bedrock/client_test.go @@ -26,7 +26,7 @@ func TestConvertMessages_UserText(t *testing.T) { Content: "Hello, world!", }} - bedrockMsgs, system := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, system := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) require.Len(t, bedrockMsgs, 1) assert.Empty(t, system) @@ -46,7 +46,7 @@ func TestConvertMessages_SystemExtraction(t *testing.T) { {Role: chat.MessageRoleUser, Content: "Hi"}, } - bedrockMsgs, system := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, system := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) require.Len(t, bedrockMsgs, 1) // Only user message require.Len(t, system, 1) // System extracted @@ -71,7 +71,7 @@ func TestConvertMessages_AssistantWithToolCalls(t *testing.T) { }}, }} - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) require.Len(t, bedrockMsgs, 1) require.Len(t, bedrockMsgs[0].Content, 1) @@ -92,7 +92,7 @@ func TestConvertMessages_ToolResult(t *testing.T) { Content: "Weather is sunny", }} - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) require.Len(t, bedrockMsgs, 1) assert.Equal(t, types.ConversationRoleUser, bedrockMsgs[0].Role) @@ -116,7 +116,7 @@ func TestConvertMessages_EmptyContent(t *testing.T) { {Role: chat.MessageRoleUser, Content: " "}, } - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) // Both messages now produce user turns with empty or whitespace content blocks. assert.Len(t, bedrockMsgs, 2) } @@ -182,7 +182,7 @@ func TestConvertMessages_MultiContent(t *testing.T) { }, }} - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) require.Len(t, bedrockMsgs, 1) require.Len(t, bedrockMsgs[0].Content, 2) @@ -206,7 +206,7 @@ func TestConvertMessages_ConsecutiveToolResults(t *testing.T) { {Role: chat.MessageRoleUser, Content: "Continue"}, } - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) // Expect: user, assistant, user (grouped tool results), user require.Len(t, bedrockMsgs, 4) @@ -1137,7 +1137,7 @@ func TestConvertMessages_WithCaching(t *testing.T) { {Role: chat.MessageRoleUser, Content: "How are you?"}, } - bedrockMsgs, system := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) + bedrockMsgs, system := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) // System should have text block + cache point require.Len(t, system, 2) @@ -1168,7 +1168,7 @@ func TestConvertMessages_WithoutCaching(t *testing.T) { {Role: chat.MessageRoleUser, Content: "Hello"}, } - bedrockMsgs, system := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) + bedrockMsgs, system := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), false) // System should only have text block, no cache point require.Len(t, system, 1) @@ -1268,7 +1268,7 @@ func TestConvertMessages_EmptyWithCaching(t *testing.T) { t.Parallel() // Empty message list should not panic with caching enabled - bedrockMsgs, system := convertMessages(t.Context(), []chat.Message{}, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) + bedrockMsgs, system := convertMessages(t.Context(), []chat.Message{}, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) assert.Empty(t, bedrockMsgs) assert.Empty(t, system) @@ -1281,7 +1281,7 @@ func TestConvertMessages_SingleMessageWithCaching(t *testing.T) { {Role: chat.MessageRoleUser, Content: "Hello"}, } - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) require.Len(t, bedrockMsgs, 1) // Single message should get a cache point appended @@ -1301,7 +1301,7 @@ func TestConvertMessages_MultiContentWithCaching(t *testing.T) { }, }} - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) require.Len(t, bedrockMsgs, 1) // 2 text blocks + cache point = 3 content blocks @@ -1324,7 +1324,7 @@ func TestConvertMessages_ToolResultWithCaching(t *testing.T) { {Role: chat.MessageRoleTool, ToolCallID: "tool-1", Content: "Result"}, } - bedrockMsgs, _ := convertMessages(t.Context(), msgs, "", modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) + bedrockMsgs, _ := convertMessages(t.Context(), msgs, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{}), true) // Expect: user, assistant, user (tool result) require.Len(t, bedrockMsgs, 3) diff --git a/pkg/model/provider/bedrock/convert.go b/pkg/model/provider/bedrock/convert.go index 73c302ee1..aa23b4abd 100644 --- a/pkg/model/provider/bedrock/convert.go +++ b/pkg/model/provider/bedrock/convert.go @@ -19,7 +19,7 @@ import ( // convertMessages handles Bedrock's Converse API constraints: // - Tool results must immediately follow the assistant message with tool_use // - Multiple consecutive tool results must be grouped into a single user message -func convertMessages(ctx context.Context, messages []chat.Message, modelID string, store *modelsdev.Store, enableCaching bool) ([]types.Message, []types.SystemContentBlock) { +func convertMessages(ctx context.Context, messages []chat.Message, id modelsdev.ID, store *modelsdev.Store, enableCaching bool) ([]types.Message, []types.SystemContentBlock) { var bedrockMessages []types.Message var systemBlocks []types.SystemContentBlock @@ -44,7 +44,7 @@ func convertMessages(ctx context.Context, messages []chat.Message, modelID strin } case chat.MessageRoleUser: - contentBlocks := convertUserContent(ctx, msg, modelID, store) + contentBlocks := convertUserContent(ctx, msg, id, store) if len(contentBlocks) > 0 { bedrockMessages = append(bedrockMessages, types.Message{ Role: types.ConversationRoleUser, @@ -121,7 +121,7 @@ func applyCachePointsToMessages(messages []types.Message) { } } -func convertUserContent(ctx context.Context, msg *chat.Message, modelID string, store *modelsdev.Store) []types.ContentBlock { +func convertUserContent(ctx context.Context, msg *chat.Message, id modelsdev.ID, store *modelsdev.Store) []types.ContentBlock { var blocks []types.ContentBlock if len(msg.MultiContent) > 0 { @@ -140,7 +140,7 @@ func convertUserContent(ctx context.Context, msg *chat.Message, modelID string, } case chat.MessagePartTypeDocument: if part.Document != nil { - docBlocks, err := convertDocument(ctx, *part.Document, modelID, store) + docBlocks, err := convertDocument(ctx, *part.Document, id, store) if err != nil { slog.WarnContext(ctx, "failed to convert document attachment", "error", err, "doc", part.Document.Name) continue diff --git a/pkg/model/provider/clone_merge_test.go b/pkg/model/provider/clone_merge_test.go index 8ae8796b8..944053730 100644 --- a/pkg/model/provider/clone_merge_test.go +++ b/pkg/model/provider/clone_merge_test.go @@ -9,6 +9,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" ) func TestMergeCloneOptions_NoOverrides(t *testing.T) { @@ -120,7 +121,7 @@ func TestMergeCloneOptions_LaterOverridesWin(t *testing.T) { func TestCloneWithOptions_FallbackOnError(t *testing.T) { // fakeProvider returns a zero-valued base.Config, so its Provider type is // empty; that always fails the factory-registry lookup in createDirectProvider. - original := &fakeProvider{id: "original"} + original := &fakeProvider{id: modelsdev.NewID("test", "original")} got := CloneWithOptions(t.Context(), original, options.WithMaxTokens(int64(2048))) diff --git a/pkg/model/provider/factory_test.go b/pkg/model/provider/factory_test.go index 3f849f786..9da027353 100644 --- a/pkg/model/provider/factory_test.go +++ b/pkg/model/provider/factory_test.go @@ -13,15 +13,16 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/tools" ) // fakeProvider is a Provider stub used to verify factory dispatch. type fakeProvider struct { - id string + id modelsdev.ID } -func (f *fakeProvider) ID() string { return f.id } +func (f *fakeProvider) ID() modelsdev.ID { return f.id } func (f *fakeProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { return nil, errors.New("not implemented") } @@ -39,7 +40,7 @@ func withFactories(t *testing.T, factories map[string]providerFactory) { func tagFactory(id string) providerFactory { return func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) { - return &fakeProvider{id: id}, nil + return &fakeProvider{id: modelsdev.NewID("test", id)}, nil } } @@ -110,7 +111,7 @@ func TestCreateDirectProvider_DispatchByType(t *testing.T) { require.NoError(t, err) fp, ok := p.(*fakeProvider) require.True(t, ok, "expected fakeProvider, got %T", p) - assert.Equal(t, tt.expectID, fp.id) + assert.Equal(t, tt.expectID, fp.id.Model) }) } } @@ -152,7 +153,7 @@ func TestCreateDirectProvider_AppliesProviderDefaults(t *testing.T) { withFactories(t, map[string]providerFactory{ "openai_chatcompletions": func(_ context.Context, cfg *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) { got = cfg - return &fakeProvider{id: "captured"}, nil + return &fakeProvider{id: modelsdev.NewID("test", "captured")}, nil }, }) diff --git a/pkg/model/provider/gemini/attachments.go b/pkg/model/provider/gemini/attachments.go index b996be577..3370063a4 100644 --- a/pkg/model/provider/gemini/attachments.go +++ b/pkg/model/provider/gemini/attachments.go @@ -20,8 +20,8 @@ import ( // - image/* or binary with InlineData → genai.Blob part // - text MIMEs with InlineText → genai.Text part with TXTEnvelope // - unsupported / no content → nil (logged as warning) -func convertDocument(ctx context.Context, doc chat.Document, modelID string, store *modelsdev.Store) (*genai.Part, error) { - mc := modelinfo.LoadCaps(store, modelID) +func convertDocument(ctx context.Context, doc chat.Document, id modelsdev.ID, store *modelsdev.Store) (*genai.Part, error) { + mc := modelinfo.LoadCaps(store, id) return convertDocumentWithCaps(ctx, doc, mc) } diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 5df5241d9..21e145361 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -194,7 +194,7 @@ func thoughtSignatureOrDefault(sig []byte) []byte { } // convertMessagesToGemini converts chat.Messages into Gemini Contents -func convertMessagesToGemini(ctx context.Context, messages []chat.Message, modelID string, store *modelsdev.Store) []*genai.Content { +func convertMessagesToGemini(ctx context.Context, messages []chat.Message, id modelsdev.ID, store *modelsdev.Store) []*genai.Content { contents := make([]*genai.Content, 0, len(messages)) for i := range messages { msg := &messages[i] @@ -259,7 +259,7 @@ func convertMessagesToGemini(ctx context.Context, messages []chat.Message, model // Handle regular messages if len(msg.MultiContent) > 0 { - parts := convertMultiContent(ctx, msg.MultiContent, msg.ThoughtSignature, modelID, store) + parts := convertMultiContent(ctx, msg.MultiContent, msg.ThoughtSignature, id, store) if len(parts) > 0 { contents = append(contents, genai.NewContentFromParts(parts, role)) } @@ -289,7 +289,7 @@ func newTextPartWithSignature(text string, signature []byte) *genai.Part { } // convertMultiContent converts multi-part content to Gemini parts -func convertMultiContent(ctx context.Context, multiContent []chat.MessagePart, thoughtSignature []byte, modelID string, store *modelsdev.Store) []*genai.Part { +func convertMultiContent(ctx context.Context, multiContent []chat.MessagePart, thoughtSignature []byte, id modelsdev.ID, store *modelsdev.Store) []*genai.Part { parts := make([]*genai.Part, 0, len(multiContent)) for _, part := range multiContent { switch part.Type { @@ -302,7 +302,7 @@ func convertMultiContent(ctx context.Context, multiContent []chat.MessagePart, t } case chat.MessagePartTypeDocument: if part.Document != nil { - docPart, err := convertDocument(ctx, *part.Document, modelID, store) + docPart, err := convertDocument(ctx, *part.Document, id, store) if err != nil { slog.WarnContext(ctx, "failed to convert document attachment", "error", err, "doc", part.Document.Name) continue diff --git a/pkg/model/provider/gemini/client_test.go b/pkg/model/provider/gemini/client_test.go index a8e83e1c5..d1e84bf00 100644 --- a/pkg/model/provider/gemini/client_test.go +++ b/pkg/model/provider/gemini/client_test.go @@ -366,7 +366,7 @@ func TestConvertMessagesToGemini_ThoughtSignature(t *testing.T) { contents := convertMessagesToGemini(t.Context(), []chat.Message{ {Role: chat.MessageRoleUser, Content: "go"}, tt.message, - }, "", modelsdev.NewDatabaseStore(&modelsdev.Database{})) + }, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{})) require.Len(t, contents, 2) assistant := contents[1] diff --git a/pkg/model/provider/google_factory_test.go b/pkg/model/provider/google_factory_test.go index d531a18e9..6a1cc5eeb 100644 --- a/pkg/model/provider/google_factory_test.go +++ b/pkg/model/provider/google_factory_test.go @@ -43,7 +43,7 @@ func TestGoogleFactory_RoutesGeminiByDefault(t *testing.T) { require.NoError(t, err) fp, ok := p.(*fakeProvider) require.True(t, ok) - assert.Equal(t, "gemini", fp.id) + assert.Equal(t, "gemini", fp.id.Model) } // TestGoogleFactory_RoutesGeminiWhenPublisherIsGoogle covers the documented @@ -68,7 +68,7 @@ func TestGoogleFactory_RoutesGeminiWhenPublisherIsGoogle(t *testing.T) { require.NoError(t, err) fp, ok := p.(*fakeProvider) require.True(t, ok) - assert.Equal(t, "gemini", fp.id) + assert.Equal(t, "gemini", fp.id.Model) } // TestGoogleFactory_RoutesVertexForModelGarden verifies that any non-Google @@ -92,7 +92,7 @@ func TestGoogleFactory_RoutesVertexForModelGarden(t *testing.T) { require.NoError(t, err) fp, ok := p.(*fakeProvider) require.True(t, ok) - assert.Equal(t, "vertex", fp.id) + assert.Equal(t, "vertex", fp.id.Model) } // TestGoogleFactory_PropagatesGeminiError verifies that errors from the inner diff --git a/pkg/model/provider/oaistream/attachments.go b/pkg/model/provider/oaistream/attachments.go index 0475a2d13..4f4c124fd 100644 --- a/pkg/model/provider/oaistream/attachments.go +++ b/pkg/model/provider/oaistream/attachments.go @@ -24,8 +24,8 @@ import ( // - other binary MIMEs with InlineData → drop (no native document block on Chat Completions) // - text MIMEs with InlineText → text part with TXTEnvelope // - unsupported / no content → nil (logged as warning) -func convertDocument(ctx context.Context, doc chat.Document, modelID string, store *modelsdev.Store) ([]openai.ChatCompletionContentPartUnionParam, error) { - mc := modelinfo.LoadCaps(store, modelID) +func convertDocument(ctx context.Context, doc chat.Document, id modelsdev.ID, store *modelsdev.Store) ([]openai.ChatCompletionContentPartUnionParam, error) { + mc := modelinfo.LoadCaps(store, id) return convertDocumentWithCaps(ctx, doc, mc) } diff --git a/pkg/model/provider/oaistream/attachments_test.go b/pkg/model/provider/oaistream/attachments_test.go index a05c24819..c6f8ab328 100644 --- a/pkg/model/provider/oaistream/attachments_test.go +++ b/pkg/model/provider/oaistream/attachments_test.go @@ -85,11 +85,11 @@ func TestConvertDocument_QualifiedIDRequired(t *testing.T) { }} // Bare model name (the original bug): image must be dropped. - partsBare := ConvertMultiContent(t.Context(), msgParts, "gpt-4o", store) + partsBare := ConvertMultiContent(t.Context(), msgParts, modelsdev.NewID("", "gpt-4o"), store) assert.Empty(t, partsBare, "bare model name must not resolve caps: image should be dropped") // Qualified ID (the fix, matching what c.ID() returns): image must be preserved. - partsQualified := ConvertMultiContent(t.Context(), msgParts, "openai/gpt-4o", store) + partsQualified := ConvertMultiContent(t.Context(), msgParts, modelsdev.NewID("openai", "gpt-4o"), store) require.Len(t, partsQualified, 1, "qualified ID must resolve caps: image should be present") assert.NotNil(t, partsQualified[0].OfImageURL, "expected image URL part for qualified model ID") } diff --git a/pkg/model/provider/oaistream/messages.go b/pkg/model/provider/oaistream/messages.go index 105ec4e15..5eef0870e 100644 --- a/pkg/model/provider/oaistream/messages.go +++ b/pkg/model/provider/oaistream/messages.go @@ -28,17 +28,17 @@ func (j JSONSchema) MarshalJSON() ([]byte, error) { // ConvertMultiContent converts chat.MessagePart slices to OpenAI content // parts using the provided modelsdev.Store for capability lookups. -func ConvertMultiContent(ctx context.Context, multiContent []chat.MessagePart, modelID string, store *modelsdev.Store) []openai.ChatCompletionContentPartUnionParam { - return convertMultiContentWithStore(ctx, multiContent, modelID, store) +func ConvertMultiContent(ctx context.Context, multiContent []chat.MessagePart, id modelsdev.ID, store *modelsdev.Store) []openai.ChatCompletionContentPartUnionParam { + return convertMultiContentWithStore(ctx, multiContent, id, store) } // ConvertMessages converts chat.Message slices to OpenAI message params // using the provided modelsdev.Store for capability lookups. -func ConvertMessages(ctx context.Context, messages []chat.Message, modelID string, store *modelsdev.Store) []openai.ChatCompletionMessageParamUnion { - return convertMessagesWithStore(ctx, messages, modelID, store) +func ConvertMessages(ctx context.Context, messages []chat.Message, id modelsdev.ID, store *modelsdev.Store) []openai.ChatCompletionMessageParamUnion { + return convertMessagesWithStore(ctx, messages, id, store) } -func convertMultiContentWithStore(ctx context.Context, multiContent []chat.MessagePart, modelID string, store *modelsdev.Store) []openai.ChatCompletionContentPartUnionParam { +func convertMultiContentWithStore(ctx context.Context, multiContent []chat.MessagePart, id modelsdev.ID, store *modelsdev.Store) []openai.ChatCompletionContentPartUnionParam { parts := make([]openai.ChatCompletionContentPartUnionParam, 0, len(multiContent)) for _, part := range multiContent { switch part.Type { @@ -54,7 +54,7 @@ func convertMultiContentWithStore(ctx context.Context, multiContent []chat.Messa } case chat.MessagePartTypeDocument: if part.Document != nil { - docParts, err := convertDocument(ctx, *part.Document, modelID, store) + docParts, err := convertDocument(ctx, *part.Document, id, store) if err != nil { slog.WarnContext(ctx, "failed to convert document attachment", "error", err, "doc", part.Document.Name) continue @@ -66,7 +66,7 @@ func convertMultiContentWithStore(ctx context.Context, multiContent []chat.Messa return parts } -func convertMessagesWithStore(ctx context.Context, messages []chat.Message, modelID string, store *modelsdev.Store) []openai.ChatCompletionMessageParamUnion { +func convertMessagesWithStore(ctx context.Context, messages []chat.Message, id modelsdev.ID, store *modelsdev.Store) []openai.ChatCompletionMessageParamUnion { openaiMessages := make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)) for i := range messages { msg := &messages[i] @@ -99,7 +99,7 @@ func convertMessagesWithStore(ctx context.Context, messages []chat.Message, mode if len(msg.MultiContent) == 0 { openaiMessage = openai.UserMessage(msg.Content) } else { - openaiMessage = openai.UserMessage(convertMultiContentWithStore(ctx, msg.MultiContent, modelID, store)) + openaiMessage = openai.UserMessage(convertMultiContentWithStore(ctx, msg.MultiContent, id, store)) } case chat.MessageRoleAssistant: diff --git a/pkg/model/provider/oaistream/messages_test.go b/pkg/model/provider/oaistream/messages_test.go index 876aa96df..4c0d6404c 100644 --- a/pkg/model/provider/oaistream/messages_test.go +++ b/pkg/model/provider/oaistream/messages_test.go @@ -63,7 +63,7 @@ func TestConvertMultiContent(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result := ConvertMultiContent(t.Context(), tt.multiContent, "", modelsdev.NewDatabaseStore(&modelsdev.Database{})) + result := ConvertMultiContent(t.Context(), tt.multiContent, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{})) assert.Len(t, result, tt.wantCount) }) } @@ -138,7 +138,7 @@ func TestConvertMessages(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result := ConvertMessages(t.Context(), tt.messages, "", modelsdev.NewDatabaseStore(&modelsdev.Database{})) + result := ConvertMessages(t.Context(), tt.messages, modelsdev.ID{}, modelsdev.NewDatabaseStore(&modelsdev.Database{})) assert.Len(t, result, tt.want) }) } diff --git a/pkg/model/provider/openai/attachments.go b/pkg/model/provider/openai/attachments.go index 800d3c918..c5aba624e 100644 --- a/pkg/model/provider/openai/attachments.go +++ b/pkg/model/provider/openai/attachments.go @@ -25,8 +25,8 @@ import ( // - application/pdf with InlineData → OfInputFile (base64) // - text MIMEs with InlineText → OfInputText with TXTEnvelope // - unsupported / no content → nil (logged as warning) -func convertDocumentToResponseInput(ctx context.Context, doc chat.Document, modelID string, store *modelsdev.Store) ([]responses.ResponseInputContentUnionParam, error) { - mc := modelinfo.LoadCaps(store, modelID) +func convertDocumentToResponseInput(ctx context.Context, doc chat.Document, id modelsdev.ID, store *modelsdev.Store) ([]responses.ResponseInputContentUnionParam, error) { + mc := modelinfo.LoadCaps(store, id) return convertDocumentToResponseInputWithCaps(ctx, doc, mc) } diff --git a/pkg/model/provider/provider.go b/pkg/model/provider/provider.go index 80b46948b..f42a81b41 100644 --- a/pkg/model/provider/provider.go +++ b/pkg/model/provider/provider.go @@ -25,14 +25,19 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" ) // Provider defines the interface for model providers. type Provider interface { - // ID returns the model provider ID - ID() string + // ID returns the provider-qualified model identity. Returning a + // [modelsdev.ID] (rather than a bare string) prevents callers from + // silently forgetting to namespace the model when it crosses an API + // boundary; use [modelsdev.ID.String] when a textual representation + // is required. + ID() modelsdev.ID // CreateChatCompletionStream creates a streaming chat completion request. // It returns a stream that can be iterated over to get completion chunks. CreateChatCompletionStream( diff --git a/pkg/model/provider/router_factory_test.go b/pkg/model/provider/router_factory_test.go index b0a3950e3..4968525bc 100644 --- a/pkg/model/provider/router_factory_test.go +++ b/pkg/model/provider/router_factory_test.go @@ -11,6 +11,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" ) // TestResolveRoutedModel_NamedReference verifies that a model name in the @@ -20,7 +21,7 @@ func TestResolveRoutedModel_NamedReference(t *testing.T) { withFactories(t, map[string]providerFactory{ "openai": func(_ context.Context, cfg *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) { capturedCfg = cfg - return &fakeProvider{id: "captured"}, nil + return &fakeProvider{id: modelsdev.NewID("openai", "captured")}, nil }, }) @@ -44,7 +45,7 @@ func TestResolveRoutedModel_InlineSpec(t *testing.T) { withFactories(t, map[string]providerFactory{ "openai": func(_ context.Context, cfg *latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) { capturedCfg = cfg - return &fakeProvider{id: "captured"}, nil + return &fakeProvider{id: modelsdev.NewID("openai", "captured")}, nil }, }) @@ -131,7 +132,7 @@ func TestResolveRoutedModel_OptionsForwarded(t *testing.T) { withFactories(t, map[string]providerFactory{ "openai": func(_ context.Context, _ *latest.ModelConfig, _ environment.Provider, opts ...options.Opt) (Provider, error) { capturedOpts = opts - return &fakeProvider{id: "ok"}, nil + return &fakeProvider{id: modelsdev.NewID("openai", "ok")}, nil }, }) diff --git a/pkg/model/provider/rulebased/client.go b/pkg/model/provider/rulebased/client.go index 76c5e9537..9f28a6b91 100644 --- a/pkg/model/provider/rulebased/client.go +++ b/pkg/model/provider/rulebased/client.go @@ -22,12 +22,13 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/tools" ) // Provider defines the minimal interface needed for model providers. type Provider interface { - ID() string + ID() modelsdev.ID CreateChatCompletionStream( ctx context.Context, messages []chat.Message, @@ -48,7 +49,7 @@ type Client struct { fallback Provider index bleve.Index mu sync.RWMutex - lastSelectedID string // ID of the provider selected by the most recent call + lastSelectedID modelsdev.ID // ID of the provider selected by the most recent call } // NewClient creates a new rule-based routing client. @@ -173,8 +174,8 @@ func (c *Client) CreateChatCompletionStream( c.lastSelectedID = selectedID c.mu.Unlock() slog.DebugContext(ctx, "Rule-based router selected model", - "router", c.ID(), - "selected_model", selectedID, + "router", c.ID().String(), + "selected_model", selectedID.String(), "message_count", len(messages), ) @@ -184,7 +185,7 @@ func (c *Client) CreateChatCompletionStream( // LastSelectedModelID returns the ID of the provider selected by the most // recent CreateChatCompletionStream call. This allows callers to display // the YAML-configured sub-model name for rule-based routing. -func (c *Client) LastSelectedModelID() string { +func (c *Client) LastSelectedModelID() modelsdev.ID { c.mu.RLock() defer c.mu.RUnlock() return c.lastSelectedID @@ -223,7 +224,7 @@ func (c *Client) selectProvider(messages []chat.Message) Provider { selected := c.routes[routeIdx] slog.Debug("Route matched", - "model", selected.ID(), + "model", selected.ID().String(), "score", hit.Score, ) return selected diff --git a/pkg/model/provider/rulebased/client_test.go b/pkg/model/provider/rulebased/client_test.go index 8ef3a2fee..ed537fb58 100644 --- a/pkg/model/provider/rulebased/client_test.go +++ b/pkg/model/provider/rulebased/client_test.go @@ -13,15 +13,16 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/tools" ) // mockProvider is a simple mock provider for testing. type mockProvider struct { - id string + id modelsdev.ID } -func (m *mockProvider) ID() string { +func (m *mockProvider) ID() modelsdev.ID { return m.id } @@ -41,9 +42,9 @@ func (m *mockProvider) BaseConfig() base.Config { // It resolves model references from the models map or parses inline specs. func mockProviderFactory(_ context.Context, modelSpec string, models map[string]latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) { if cfg, exists := models[modelSpec]; exists { - return &mockProvider{id: cfg.Provider + "/" + cfg.Model}, nil + return &mockProvider{id: modelsdev.NewID(cfg.Provider, cfg.Model)}, nil } - return &mockProvider{id: modelSpec}, nil + return &mockProvider{id: modelsdev.ParseIDOrZero(modelSpec)}, nil } func TestNewClient(t *testing.T) { @@ -201,7 +202,7 @@ func TestClient_SelectProvider(t *testing.T) { messages := []chat.Message{{Role: chat.MessageRoleUser, Content: tt.message}} provider := client.selectProvider(messages) require.NotNil(t, provider) - assert.Equal(t, tt.expectedModel, provider.ID()) + assert.Equal(t, tt.expectedModel, provider.ID().String()) }) } } @@ -288,7 +289,7 @@ func TestClient_ID(t *testing.T) { require.NoError(t, err) defer client.Close() - assert.Equal(t, "openai/gpt-4o", client.ID()) + assert.Equal(t, "openai/gpt-4o", client.ID().String()) } func TestClient_DefaultProvider(t *testing.T) { @@ -310,7 +311,7 @@ func TestClient_DefaultProvider(t *testing.T) { defer client.Close() provider := client.selectProvider(nil) - assert.Equal(t, "openai/gpt-4o", provider.ID()) + assert.Equal(t, "openai/gpt-4o", provider.ID().String()) } func TestClient_CreateChatCompletionStream_NilProvider(t *testing.T) { diff --git a/pkg/model/provider/vertexai/modelgarden.go b/pkg/model/provider/vertexai/modelgarden.go index e0a40bf3a..1cb30002e 100644 --- a/pkg/model/provider/vertexai/modelgarden.go +++ b/pkg/model/provider/vertexai/modelgarden.go @@ -44,6 +44,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/openai" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/tools" ) @@ -59,7 +60,7 @@ var validGCPIdentifier = regexp.MustCompile(`^[a-z][a-z0-9-]{1,29}$`) // anthropic.Client and openai.Client satisfy it, so the caller can treat // the two Model Garden code paths uniformly. type Client interface { - ID() string + ID() modelsdev.ID CreateChatCompletionStream(ctx context.Context, messages []chat.Message, tools []tools.Tool) (chat.MessageStream, error) BaseConfig() base.Config } diff --git a/pkg/modelinfo/modelinfo.go b/pkg/modelinfo/modelinfo.go index 24a9e2447..cd57268c1 100644 --- a/pkg/modelinfo/modelinfo.go +++ b/pkg/modelinfo/modelinfo.go @@ -158,27 +158,27 @@ func IsBedrockClaudeID(modelID string) bool { // // Pass a nil store to skip the models.dev lookup entirely (the name-pattern // fallback still works, which is fine for the common case). -func IsClaude(ctx context.Context, store *modelsdev.Store, providerID, modelID string) bool { - if family := LookupFamily(ctx, store, providerID, modelID); family != "" { +func IsClaude(ctx context.Context, store *modelsdev.Store, id modelsdev.ID) bool { + if family := LookupFamily(ctx, store, id); family != "" { return IsClaudeFamily(family) } - if IsBedrockClaudeID(modelID) { + if IsBedrockClaudeID(id.Model) { return true } - return strings.HasPrefix(normalize(modelID), "claude-") + return strings.HasPrefix(normalize(id.Model), "claude-") } // LookupFamily returns the canonical model family identifier from models.dev // (e.g. "claude-opus", "claude-sonnet", "gemini-pro", "o", "o-mini", "gpt"). // -// Returns "" when the store is nil, the providerID/modelID is empty, or the +// Returns "" when the store is nil, the id is incomplete, or the // model is not registered in the database. Callers that want a non-empty // answer for unknown models should fall back to a name-pattern heuristic. -func LookupFamily(ctx context.Context, store *modelsdev.Store, providerID, modelID string) string { - if store == nil || providerID == "" || modelID == "" { +func LookupFamily(ctx context.Context, store *modelsdev.Store, id modelsdev.ID) string { + if store == nil || !id.IsValid() { return "" } - m, err := store.GetModel(ctx, providerID+"/"+modelID) + m, err := store.GetModel(ctx, id) if err != nil || m == nil { return "" } @@ -244,13 +244,11 @@ func (mc ModelCapabilities) Supports(mimeType string) bool { const loadCapsTimeout = 10 * time.Second // LoadCaps fetches (or returns from cache) the capability record for the given -// model ID using the provided store. The model ID should be in -// "provider/model" format as used by models.dev -// (e.g. "anthropic/claude-3-5-sonnet-20241022"). +// model ID using the provided store. // // When the store is nil or the model is not found, LoadCaps returns a // conservative capability set that only allows text MIME types. -func LoadCaps(store *modelsdev.Store, modelID string) ModelCapabilities { +func LoadCaps(store *modelsdev.Store, id modelsdev.ID) ModelCapabilities { if store == nil { return ModelCapabilities{} } @@ -258,11 +256,11 @@ func LoadCaps(store *modelsdev.Store, modelID string) ModelCapabilities { ctx, cancel := context.WithTimeout(context.Background(), loadCapsTimeout) defer cancel() - model, err := store.GetModel(ctx, modelID) + model, err := store.GetModel(ctx, id) if err != nil { if ctx.Err() != nil { slog.WarnContext(ctx, "modelinfo: models.dev lookup timed out, using conservative caps", - "model", modelID, "timeout", loadCapsTimeout) + "model", id.String(), "timeout", loadCapsTimeout) } return ModelCapabilities{} } diff --git a/pkg/modelinfo/modelinfo_test.go b/pkg/modelinfo/modelinfo_test.go index 0f28b45e3..3ca0fb434 100644 --- a/pkg/modelinfo/modelinfo_test.go +++ b/pkg/modelinfo/modelinfo_test.go @@ -253,29 +253,29 @@ func TestLookupFamily(t *testing.T) { t.Run("known", func(t *testing.T) { t.Parallel() - assert.Equal(t, "claude-sonnet", LookupFamily(t.Context(), store, "anthropic", "claude-sonnet-4-5")) + assert.Equal(t, "claude-sonnet", LookupFamily(t.Context(), store, modelsdev.NewID("anthropic", "claude-sonnet-4-5"))) }) t.Run("known on bedrock", func(t *testing.T) { t.Parallel() - got := LookupFamily(t.Context(), store, "amazon-bedrock", "anthropic.claude-sonnet-4-5-20250929-v1:0") + got := LookupFamily(t.Context(), store, modelsdev.NewID("amazon-bedrock", "anthropic.claude-sonnet-4-5-20250929-v1:0")) assert.Equal(t, "claude-sonnet", got) }) t.Run("unknown model", func(t *testing.T) { t.Parallel() - assert.Empty(t, LookupFamily(t.Context(), store, "anthropic", "claude-future")) + assert.Empty(t, LookupFamily(t.Context(), store, modelsdev.NewID("anthropic", "claude-future"))) }) t.Run("unknown provider", func(t *testing.T) { t.Parallel() - assert.Empty(t, LookupFamily(t.Context(), store, "no-such-provider", "x")) + assert.Empty(t, LookupFamily(t.Context(), store, modelsdev.NewID("no-such-provider", "x"))) }) t.Run("nil store", func(t *testing.T) { t.Parallel() - assert.Empty(t, LookupFamily(t.Context(), nil, "anthropic", "claude-sonnet-4-5")) + assert.Empty(t, LookupFamily(t.Context(), nil, modelsdev.NewID("anthropic", "claude-sonnet-4-5"))) }) t.Run("empty inputs", func(t *testing.T) { t.Parallel() - assert.Empty(t, LookupFamily(t.Context(), store, "", "claude-sonnet-4-5")) - assert.Empty(t, LookupFamily(t.Context(), store, "anthropic", "")) + assert.Empty(t, LookupFamily(t.Context(), store, modelsdev.NewID("", "claude-sonnet-4-5"))) + assert.Empty(t, LookupFamily(t.Context(), store, modelsdev.NewID("anthropic", ""))) }) } @@ -300,22 +300,22 @@ func TestIsClaude(t *testing.T) { ctx := t.Context() // Resolved via models.dev. - assert.True(t, IsClaude(ctx, store, "anthropic", "claude-sonnet-4-5")) - assert.True(t, IsClaude(ctx, store, "vertex-anthropic", "claude-opus-4-7")) + assert.True(t, IsClaude(ctx, store, modelsdev.NewID("anthropic", "claude-sonnet-4-5"))) + assert.True(t, IsClaude(ctx, store, modelsdev.NewID("vertex-anthropic", "claude-opus-4-7"))) // Resolved via Bedrock-style name pattern even without store data. - assert.True(t, IsClaude(ctx, nil, "amazon-bedrock", "anthropic.claude-3-5-sonnet-20241022-v2:0")) - assert.True(t, IsClaude(ctx, nil, "amazon-bedrock", "global.anthropic.claude-opus-4-5-20251101-v1:0")) + assert.True(t, IsClaude(ctx, nil, modelsdev.NewID("amazon-bedrock", "anthropic.claude-3-5-sonnet-20241022-v2:0"))) + assert.True(t, IsClaude(ctx, nil, modelsdev.NewID("amazon-bedrock", "global.anthropic.claude-opus-4-5-20251101-v1:0"))) // Resolved via bare-name fallback. - assert.True(t, IsClaude(ctx, nil, "anthropic", "claude-future")) + assert.True(t, IsClaude(ctx, nil, modelsdev.NewID("anthropic", "claude-future"))) // Definitively not Claude. - assert.False(t, IsClaude(ctx, store, "openai", "gpt-4o")) - assert.False(t, IsClaude(ctx, nil, "openai", "gpt-4o")) - assert.False(t, IsClaude(ctx, nil, "amazon-bedrock", "amazon.titan-text-express-v1")) - assert.False(t, IsClaude(ctx, nil, "google", "gemini-2.5-pro")) - assert.False(t, IsClaude(ctx, nil, "", "")) + assert.False(t, IsClaude(ctx, store, modelsdev.NewID("openai", "gpt-4o"))) + assert.False(t, IsClaude(ctx, nil, modelsdev.NewID("openai", "gpt-4o"))) + assert.False(t, IsClaude(ctx, nil, modelsdev.NewID("amazon-bedrock", "amazon.titan-text-express-v1"))) + assert.False(t, IsClaude(ctx, nil, modelsdev.NewID("google", "gemini-2.5-pro"))) + assert.False(t, IsClaude(ctx, nil, modelsdev.ID{})) } func TestIsClaude_StoreErrorFallsBackToPattern(t *testing.T) { @@ -325,8 +325,8 @@ func TestIsClaude_StoreErrorFallsBackToPattern(t *testing.T) { // the bare-name fallback to identify Claude models correctly. store := modelsdev.NewDatabaseStore(&modelsdev.Database{Providers: map[string]modelsdev.Provider{}}) - require.True(t, IsClaude(t.Context(), store, "anthropic", "claude-sonnet-4-5")) - require.False(t, IsClaude(t.Context(), store, "openai", "gpt-4o")) + require.True(t, IsClaude(t.Context(), store, modelsdev.NewID("anthropic", "claude-sonnet-4-5"))) + require.False(t, IsClaude(t.Context(), store, modelsdev.NewID("openai", "gpt-4o"))) } // --------------------------------------------------------------------------- @@ -349,20 +349,20 @@ func TestLoadCaps_QualifiedIDRequired(t *testing.T) { }}) // Bare model name: must fall back to conservative text-only caps. - bareID := "claude-sonnet-4-6" + bareID := modelsdev.NewID("", "claude-sonnet-4-6") mcBare := LoadCaps(store, bareID) assert.False(t, mcBare.Supports("image/jpeg"), - "bare model name %q must NOT resolve to vision caps", bareID) + "bare model name %q must NOT resolve to vision caps", bareID.String()) assert.False(t, mcBare.Supports("application/pdf"), - "bare model name %q must NOT resolve to PDF caps", bareID) + "bare model name %q must NOT resolve to PDF caps", bareID.String()) // Fully-qualified ID: must resolve to vision+pdf caps. - qualifiedID := "anthropic/claude-sonnet-4-6" + qualifiedID := modelsdev.NewID("anthropic", "claude-sonnet-4-6") mcQualified := LoadCaps(store, qualifiedID) assert.True(t, mcQualified.Supports("image/jpeg"), - "qualified ID %q must resolve to vision caps", qualifiedID) + "qualified ID %q must resolve to vision caps", qualifiedID.String()) assert.True(t, mcQualified.Supports("application/pdf"), - "qualified ID %q must resolve to PDF caps", qualifiedID) + "qualified ID %q must resolve to PDF caps", qualifiedID.String()) } func TestLoadCaps_VisionModel(t *testing.T) { @@ -380,7 +380,7 @@ func TestLoadCaps_VisionModel(t *testing.T) { }, }}) - mc := LoadCaps(store, "anthropic/claude-3-5-sonnet") + mc := LoadCaps(store, modelsdev.NewID("anthropic", "claude-3-5-sonnet")) assert.True(t, mc.Supports("image/jpeg")) assert.True(t, mc.Supports("image/png")) @@ -403,7 +403,7 @@ func TestLoadCaps_TextOnlyModel(t *testing.T) { }, }}) - mc := LoadCaps(store, "openai/gpt-3.5-turbo") + mc := LoadCaps(store, modelsdev.NewID("openai", "gpt-3.5-turbo")) assert.False(t, mc.Supports("image/jpeg")) assert.False(t, mc.Supports("application/pdf")) @@ -414,7 +414,7 @@ func TestLoadCaps_TextOnlyModel(t *testing.T) { func TestLoadCaps_ModelNotFound(t *testing.T) { store := modelsdev.NewDatabaseStore(&modelsdev.Database{Providers: map[string]modelsdev.Provider{}}) - mc := LoadCaps(store, "unknown/nonexistent-model") + mc := LoadCaps(store, modelsdev.NewID("unknown", "nonexistent-model")) assert.False(t, mc.Supports("image/jpeg")) assert.False(t, mc.Supports("application/pdf")) @@ -436,7 +436,7 @@ func TestLoadCaps_OfficeDocsNotAllowed(t *testing.T) { }, }}) - mc := LoadCaps(store, "openai/gpt-4o") + mc := LoadCaps(store, modelsdev.NewID("openai", "gpt-4o")) for _, officeMIME := range []string{ "application/vnd.openxmlformats-officedocument.wordprocessingml.document", diff --git a/pkg/modelsdev/id.go b/pkg/modelsdev/id.go new file mode 100644 index 000000000..130de9037 --- /dev/null +++ b/pkg/modelsdev/id.go @@ -0,0 +1,74 @@ +package modelsdev + +import ( + "fmt" + "strings" +) + +// ID identifies a model in the models.dev catalog by provider and model +// name. It exists so callers can no longer accidentally pass a bare model +// name (e.g. "claude-sonnet-4-6") where a "provider/model" pair is required: +// the compiler rejects a [string] argument and forces one of the +// constructors below. +// +// The zero value is the empty ID and reports IsZero() == true. Use +// [NewID], [ParseID], or [ParseIDOrZero] to construct values; use +// [ID.String] when a "provider/model" representation is required at a +// boundary (slog fields, event payloads, error messages, ...). +type ID struct { + Provider string + Model string +} + +// NewID returns an ID for the given provider and model name. Either +// component may be empty (for example for a provider-less model spec +// during config parsing); call [ID.IsZero] to test for the empty ID and +// [ID.IsValid] to check that both components are populated. +func NewID(provider, model string) ID { + return ID{Provider: provider, Model: model} +} + +// ParseID parses a "provider/model" reference. Either component must be +// non-empty and the separator must be present; otherwise an error is +// returned. The function does not validate that the provider or model +// exists in the models.dev catalog. +func ParseID(ref string) (ID, error) { + provider, model, ok := strings.Cut(ref, "/") + if !ok || provider == "" || model == "" { + return ID{}, fmt.Errorf("invalid model reference %q: expected 'provider/model' format", ref) + } + return ID{Provider: provider, Model: model}, nil +} + +// ParseIDOrZero parses a "provider/model" reference and returns the zero +// ID when the input is malformed. Use this on best-effort code paths +// (logs, telemetry labels) where a malformed reference should not +// surface as an error. +func ParseIDOrZero(ref string) ID { + id, err := ParseID(ref) + if err != nil { + return ID{} + } + return id +} + +// String returns the canonical "provider/model" representation. When +// either component is empty the separator is still emitted so the +// output round-trips through [ParseID] only when both fields are set. +// For the zero ID the result is the empty string. +func (id ID) String() string { + if id.IsZero() { + return "" + } + return id.Provider + "/" + id.Model +} + +// IsZero reports whether the ID has both components empty. +func (id ID) IsZero() bool { + return id.Provider == "" && id.Model == "" +} + +// IsValid reports whether both components of the ID are populated. +func (id ID) IsValid() bool { + return id.Provider != "" && id.Model != "" +} diff --git a/pkg/modelsdev/id_test.go b/pkg/modelsdev/id_test.go new file mode 100644 index 000000000..3db9ef906 --- /dev/null +++ b/pkg/modelsdev/id_test.go @@ -0,0 +1,93 @@ +package modelsdev + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewID(t *testing.T) { + t.Parallel() + + id := NewID("openai", "gpt-4o") + assert.Equal(t, "openai", id.Provider) + assert.Equal(t, "gpt-4o", id.Model) + assert.Equal(t, "openai/gpt-4o", id.String()) + assert.True(t, id.IsValid()) + assert.False(t, id.IsZero()) +} + +func TestParseID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ref string + wantID ID + wantErr bool + wantStringR bool // round-trips via String() + }{ + {"valid", "openai/gpt-4o", ID{Provider: "openai", Model: "gpt-4o"}, false, true}, + {"valid with slash in model", "openai/foo/bar", ID{Provider: "openai", Model: "foo/bar"}, false, false}, + {"missing separator", "openai-gpt-4o", ID{}, true, false}, + {"missing provider", "/gpt-4o", ID{}, true, false}, + {"missing model", "openai/", ID{}, true, false}, + {"empty", "", ID{}, true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := ParseID(tt.ref) + if tt.wantErr { + require.Error(t, err) + assert.Equal(t, ID{}, got) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantID, got) + if tt.wantStringR { + assert.Equal(t, tt.ref, got.String()) + } + }) + } +} + +func TestParseIDOrZero(t *testing.T) { + t.Parallel() + + id := ParseIDOrZero("openai/gpt-4o") + assert.Equal(t, ID{Provider: "openai", Model: "gpt-4o"}, id) + + id = ParseIDOrZero("not-a-ref") + assert.True(t, id.IsZero()) + assert.Empty(t, id.String()) +} + +func TestIDZero(t *testing.T) { + t.Parallel() + + var id ID + assert.True(t, id.IsZero()) + assert.False(t, id.IsValid()) + assert.Empty(t, id.String()) +} + +func TestIDIsValid(t *testing.T) { + t.Parallel() + + tests := []struct { + id ID + want bool + }{ + {ID{Provider: "openai", Model: "gpt-4o"}, true}, + {ID{Provider: "openai"}, false}, + {ID{Model: "gpt-4o"}, false}, + {ID{}, false}, + } + for _, tt := range tests { + assert.Equal(t, tt.want, tt.id.IsValid(), "id=%+v", tt.id) + } +} diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index 0208aa825..1c7438829 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -120,33 +120,32 @@ func (s *Store) getProvider(ctx context.Context, providerID string) (*Provider, return &provider, nil } -// GetModel returns a specific model by provider ID and model ID. -func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { - parts := strings.SplitN(id, "/", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid model ID: %q", id) +// GetModel returns a specific model by ID. The ID must carry both a +// provider and a model component; pass the result of [NewID], [ParseID], +// or a provider's [ID] method. +func (s *Store) GetModel(ctx context.Context, id ID) (*Model, error) { + if !id.IsValid() { + return nil, fmt.Errorf("invalid model ID: %q", id.String()) } - providerID := parts[0] - modelID := parts[1] - provider, err := s.getProvider(ctx, providerID) + provider, err := s.getProvider(ctx, id.Provider) if err != nil { return nil, err } - model, exists := provider.Models[modelID] + model, exists := provider.Models[id.Model] // For amazon-bedrock, try stripping region/inference profile prefixes. // Bedrock uses prefixes for cross-region inference profiles, // but models.dev stores models without these prefixes. - if !exists && providerID == "amazon-bedrock" { - if prefix, after, ok := strings.Cut(modelID, "."); ok && bedrockRegionPrefixes[prefix] { + if !exists && id.Provider == "amazon-bedrock" { + if prefix, after, ok := strings.Cut(id.Model, "."); ok && bedrockRegionPrefixes[prefix] { model, exists = provider.Models[after] } } if !exists { - return nil, fmt.Errorf("model %q not found in provider %q", modelID, providerID) + return nil, fmt.Errorf("model %q not found in provider %q", id.Model, id.Provider) } return &model, nil diff --git a/pkg/rag/rerank/rerank_test.go b/pkg/rag/rerank/rerank_test.go index 0a514fae3..84b536ecc 100644 --- a/pkg/rag/rerank/rerank_test.go +++ b/pkg/rag/rerank/rerank_test.go @@ -10,6 +10,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/rag/database" "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" @@ -23,8 +24,8 @@ type fakeRerankingProvider struct { err error } -func (f *fakeRerankingProvider) ID() string { - return "fake-reranker" +func (f *fakeRerankingProvider) ID() modelsdev.ID { + return modelsdev.NewID("test", "fake-reranker") } func (f *fakeRerankingProvider) CreateChatCompletionStream( @@ -66,8 +67,8 @@ type fakeProviderWithoutRerank struct { base.Config } -func (f *fakeProviderWithoutRerank) ID() string { - return "fake-no-rerank" +func (f *fakeProviderWithoutRerank) ID() modelsdev.ID { + return modelsdev.NewID("test", "fake-no-rerank") } func (f *fakeProviderWithoutRerank) CreateChatCompletionStream( diff --git a/pkg/rag/strategy/embedding.go b/pkg/rag/strategy/embedding.go index f9f133e58..def63a2ee 100644 --- a/pkg/rag/strategy/embedding.go +++ b/pkg/rag/strategy/embedding.go @@ -16,7 +16,7 @@ import ( // EmbeddingConfig holds configuration for creating an embedding provider. type EmbeddingConfig struct { Provider provider.Provider - ModelID string // Full model ID for pricing (e.g., "openai/text-embedding-3-small") + ModelID modelsdev.ID // Provider/model identity, used for pricing lookup. ModelsStore *modelsdev.Store } @@ -47,11 +47,11 @@ func CreateEmbeddingProvider(ctx context.Context, modelName string, buildCtx Bui } // Determine model ID for pricing lookup - var modelID string + var modelID modelsdev.ID if modelName == "auto" { modelID = embedModel.ID() } else { - modelID = modelCfg.Provider + "/" + modelCfg.Model + modelID = modelsdev.NewID(modelCfg.Provider, modelCfg.Model) } var modelsStore *modelsdev.Store diff --git a/pkg/rag/strategy/semantic_embeddings.go b/pkg/rag/strategy/semantic_embeddings.go index 06b619159..c478dc529 100644 --- a/pkg/rag/strategy/semantic_embeddings.go +++ b/pkg/rag/strategy/semantic_embeddings.go @@ -16,6 +16,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/js" "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/rag/chunk" "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" @@ -94,8 +95,8 @@ func NewSemanticEmbeddingsFromConfig(ctx context.Context, cfg latest.RAGStrategy } chatModelID := chatProvider.ID() - if chatModelID == "" && chatModelCfg.Provider != "" && chatModelCfg.Model != "" { - chatModelID = fmt.Sprintf("%s/%s", chatModelCfg.Provider, chatModelCfg.Model) + if chatModelID.IsZero() && chatModelCfg.Provider != "" && chatModelCfg.Model != "" { + chatModelID = modelsdev.NewID(chatModelCfg.Provider, chatModelCfg.Model) } // Get optional parameters with defaults @@ -500,15 +501,15 @@ func humanizeMetadataKey(key string) string { } // calculateSemanticUsageCost calculates cost for semantic LLM usage. -func calculateSemanticUsageCost(modelsStore modelStore, modelID string, usage *chat.Usage) float64 { - if usage == nil || modelsStore == nil || modelID == "" || strings.HasPrefix(modelID, "dmr/") { +func calculateSemanticUsageCost(modelsStore modelStore, id modelsdev.ID, usage *chat.Usage) float64 { + if usage == nil || modelsStore == nil || !id.IsValid() || id.Provider == "dmr" { return 0 } - model, err := modelsStore.GetModel(context.Background(), modelID) + model, err := modelsStore.GetModel(context.Background(), id) if err != nil { slog.Debug("Failed to get semantic model pricing from models.dev, cost will be 0", - "model_id", modelID, + "model_id", id.String(), "error", err) return 0 } diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 1574c6b77..9d4935b92 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -63,7 +63,7 @@ type VectorStore struct { indexingTokens int64 // Track tokens used during indexing indexingCost float64 - modelID string // Full model ID (e.g., "openai/text-embedding-3-small") for pricing lookup + modelID modelsdev.ID // Provider/model identity, used for pricing lookup. modelsStore modelStore // embeddingInputBuilder controls how raw chunks are transformed into the @@ -88,7 +88,7 @@ type VectorStore struct { } type modelStore interface { - GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) + GetModel(ctx context.Context, id modelsdev.ID) (*modelsdev.Model, error) } // EmbeddingInputBuilder builds the string that will be sent to the embedding model @@ -114,7 +114,7 @@ type VectorStoreConfig struct { Embedder *embed.Embedder Events chan<- types.Event SimilarityMetric string - ModelID string + ModelID modelsdev.ID ModelsStore modelStore EmbeddingConcurrency int FileIndexConcurrency int @@ -171,14 +171,14 @@ func (s *VectorStore) SetEmbeddingInputBuilder(builder EmbeddingInputBuilder) { // calculateCost calculates embedding cost using models.dev pricing func (s *VectorStore) calculateCost(tokens int64) float64 { - if s.modelsStore == nil || strings.HasPrefix(s.modelID, "dmr/") { + if s.modelsStore == nil || s.modelID.Provider == "dmr" { return 0 } model, err := s.modelsStore.GetModel(context.Background(), s.modelID) if err != nil { slog.Debug("Failed to get model pricing from models.dev, cost will be 0", - "model_id", s.modelID, + "model_id", s.modelID.String(), "error", err) return 0 } diff --git a/pkg/runtime/agent_delegation.go b/pkg/runtime/agent_delegation.go index 1c24a393a..5d0d04f17 100644 --- a/pkg/runtime/agent_delegation.go +++ b/pkg/runtime/agent_delegation.go @@ -218,12 +218,12 @@ func (r *LocalRuntime) swapCurrentAgent(ctx context.Context, sessionID string, f evts.Emit(AgentSwitching(true, from.Name(), to.Name())) r.executeOnAgentSwitchHooks(ctx, from, sessionID, from.Name(), to.Name(), agentSwitchKindTransferTask) r.setCurrentAgent(to.Name()) - evts.Emit(AgentInfo(to.Name(), getAgentModelID(to), to.Description(), to.WelcomeMessage())) + evts.Emit(AgentInfo(to.Name(), getAgentModelID(to).String(), to.Description(), to.WelcomeMessage())) return func() { r.setCurrentAgent(from.Name()) evts.Emit(AgentSwitching(false, to.Name(), from.Name())) r.executeOnAgentSwitchHooks(ctx, from, sessionID, to.Name(), from.Name(), agentSwitchKindTransferTaskReturn) - evts.Emit(AgentInfo(from.Name(), getAgentModelID(from), from.Description(), from.WelcomeMessage())) + evts.Emit(AgentInfo(from.Name(), getAgentModelID(from).String(), from.Description(), from.WelcomeMessage())) } } diff --git a/pkg/runtime/cache.go b/pkg/runtime/cache.go index 03bc62df7..c022de69a 100644 --- a/pkg/runtime/cache.go +++ b/pkg/runtime/cache.go @@ -75,7 +75,7 @@ func (r *LocalRuntime) tryReplayCachedResponse( slog.DebugContext(ctx, "Response cache hit; replaying cached answer", "agent", a.Name(), "session_id", sess.ID) - modelID := a.Model(ctx).ID() + modelID := a.Model(ctx).ID().String() events.Emit(AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())) addAgentMessage(sess, a, &chat.Message{ Role: chat.MessageRoleAssistant, diff --git a/pkg/runtime/compactor/compactor_test.go b/pkg/runtime/compactor/compactor_test.go index f85f0f6db..492737831 100644 --- a/pkg/runtime/compactor/compactor_test.go +++ b/pkg/runtime/compactor/compactor_test.go @@ -13,13 +13,14 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/compaction" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/tools" ) -type fakeProvider struct{ id string } +type fakeProvider struct{ id modelsdev.ID } -func (p fakeProvider) ID() string { return p.id } +func (p fakeProvider) ID() modelsdev.ID { return p.id } func (p fakeProvider) BaseConfig() base.Config { return base.Config{} } @@ -350,7 +351,7 @@ func TestRunLLM_DoesNotDuplicateSystemPrompt(t *testing.T) { sess := session.New(session.WithMessages([]session.Item{ session.NewMessageItem(&session.Message{Message: chat.Message{Role: chat.MessageRoleUser, Content: "please summarize"}}), })) - a := agent.New("test", "parent prompt", agent.WithModel(fakeProvider{id: "fake/model"})) + a := agent.New("test", "parent prompt", agent.WithModel(fakeProvider{id: modelsdev.NewID("fake", "model")})) var systemPromptCount int result, err := RunLLM(t.Context(), LLMArgs{ diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index dab915ceb..15216a519 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -92,7 +92,7 @@ func logFallbackAttempt(agentName string, model modelWithFallback, attempt, maxR if model.isFallback { slog.Warn("Fallback model attempt", "agent", agentName, - "model", model.provider.ID(), + "model", model.provider.ID().String(), "fallback_index", model.index, "attempt", attempt+1, "max_retries", maxRetries+1, @@ -100,16 +100,16 @@ func logFallbackAttempt(agentName string, model modelWithFallback, attempt, maxR } else { slog.Warn("Primary model failed, trying fallbacks", "agent", agentName, - "model", model.provider.ID(), + "model", model.provider.ID().String(), "error", err) } } // logRetryBackoff logs when we're backing off before a retry -func logRetryBackoff(agentName, modelID string, attempt int, backoffDelay time.Duration) { +func logRetryBackoff(agentName string, modelID modelsdev.ID, attempt int, backoffDelay time.Duration) { slog.Debug("Backing off before retry", "agent", agentName, - "model", modelID, + "model", modelID.String(), "attempt", attempt+1, "backoff", backoffDelay) } @@ -274,8 +274,8 @@ func (e *fallbackExecutor) execute( } events.Emit(ModelFallback( a.Name(), - prevModelID, - modelEntry.provider.ID(), + prevModelID.String(), + modelEntry.provider.ID().String(), reason, attempt+1, maxAttempts, @@ -284,7 +284,7 @@ func (e *fallbackExecutor) execute( slog.DebugContext(ctx, "Creating chat completion stream", "agent", a.Name(), - "model", modelEntry.provider.ID(), + "model", modelEntry.provider.ID().String(), "is_fallback", modelEntry.isFallback, "in_cooldown", startIndex > 0, "attempt", attempt+1) @@ -302,13 +302,15 @@ func (e *fallbackExecutor) execute( continue } - slog.DebugContext(ctx, "Processing stream", "agent", a.Name(), "model", modelEntry.provider.ID()) + slog.DebugContext(ctx, "Processing stream", "agent", a.Name(), "model", modelEntry.provider.ID().String()) // If the provider is a rule-based router, notify the sidebar // of the selected sub-model's YAML-configured name. - if rp, ok := modelEntry.provider.(interface{ LastSelectedModelID() string }); ok { - if selected := rp.LastSelectedModelID(); selected != "" { - events.Emit(AgentInfo(a.Name(), selected, a.Description(), a.WelcomeMessage())) + if rp, ok := modelEntry.provider.(interface { + LastSelectedModelID() modelsdev.ID + }); ok { + if selected := rp.LastSelectedModelID(); !selected.IsZero() { + events.Emit(AgentInfo(a.Name(), selected.String(), a.Description(), a.WelcomeMessage())) } } @@ -384,7 +386,7 @@ func (e *fallbackExecutor) handleModelError( if !e.retryOnRateLimit || hasFallbacks { slog.WarnContext(ctx, "Rate limited, treating as non-retryable", "agent", a.Name(), - "model", modelEntry.provider.ID(), + "model", modelEntry.provider.ID().String(), "retry_on_rate_limit_enabled", e.retryOnRateLimit, "has_fallbacks", hasFallbacks, "error", err) @@ -401,14 +403,14 @@ func (e *fallbackExecutor) handleModelError( } else if waitDuration > backoff.MaxRetryAfterWait { slog.WarnContext(ctx, "Retry-After exceeds maximum, capping", "agent", a.Name(), - "model", modelEntry.provider.ID(), + "model", modelEntry.provider.ID().String(), "retry_after", retryAfter, "max", backoff.MaxRetryAfterWait) waitDuration = backoff.MaxRetryAfterWait } slog.WarnContext(ctx, "Rate limited, retrying (opt-in enabled)", "agent", a.Name(), - "model", modelEntry.provider.ID(), + "model", modelEntry.provider.ID().String(), "attempt", attempt+1, "wait", waitDuration, "retry_after_from_header", retryAfter > 0, @@ -422,7 +424,7 @@ func (e *fallbackExecutor) handleModelError( if !retryable { slog.ErrorContext(ctx, "Non-retryable error from model", "agent", a.Name(), - "model", modelEntry.provider.ID(), + "model", modelEntry.provider.ID().String(), "error", err) if !modelEntry.isFallback { *primaryFailedWithNonRetryable = true @@ -432,7 +434,7 @@ func (e *fallbackExecutor) handleModelError( slog.WarnContext(ctx, "Retryable error from model", "agent", a.Name(), - "model", modelEntry.provider.ID(), + "model", modelEntry.provider.ID().String(), "attempt", attempt+1, "error", err) return retryDecisionContinue diff --git a/pkg/runtime/fallback_test.go b/pkg/runtime/fallback_test.go index d3babe169..85dbfa15e 100644 --- a/pkg/runtime/fallback_test.go +++ b/pkg/runtime/fallback_test.go @@ -15,6 +15,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/modelerrors" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" @@ -26,7 +27,7 @@ type failingProvider struct { err error } -func (p *failingProvider) ID() string { return p.id } +func (p *failingProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(p.id) } func (p *failingProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { return nil, p.err } @@ -42,7 +43,7 @@ type countingProvider struct { stream chat.MessageStream } -func (p *countingProvider) ID() string { return p.id } +func (p *countingProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(p.id) } func (p *countingProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { p.callCount++ if p.callCount <= p.failCount { @@ -71,7 +72,7 @@ func TestBuildModelChain(t *testing.T) { t.Parallel() chain := buildModelChain(primary, nil) require.Len(t, chain, 1) - assert.Equal(t, primary.ID(), chain[0].provider.ID()) + assert.Equal(t, primary.ID().String(), chain[0].provider.ID().String()) assert.False(t, chain[0].isFallback) assert.Equal(t, -1, chain[0].index) }) @@ -81,14 +82,14 @@ func TestBuildModelChain(t *testing.T) { chain := buildModelChain(primary, []provider.Provider{fallback1, fallback2}) require.Len(t, chain, 3) - assert.Equal(t, primary.ID(), chain[0].provider.ID()) + assert.Equal(t, primary.ID().String(), chain[0].provider.ID().String()) assert.False(t, chain[0].isFallback) - assert.Equal(t, fallback1.ID(), chain[1].provider.ID()) + assert.Equal(t, fallback1.ID().String(), chain[1].provider.ID().String()) assert.True(t, chain[1].isFallback) assert.Equal(t, 0, chain[1].index) - assert.Equal(t, fallback2.ID(), chain[2].provider.ID()) + assert.Equal(t, fallback2.ID().String(), chain[2].provider.ID().String()) assert.True(t, chain[2].isFallback) assert.Equal(t, 1, chain[2].index) }) diff --git a/pkg/runtime/lazy_model_store.go b/pkg/runtime/lazy_model_store.go index b19c35d52..7bab630f9 100644 --- a/pkg/runtime/lazy_model_store.go +++ b/pkg/runtime/lazy_model_store.go @@ -31,12 +31,12 @@ func (l *lazyModelStore) load() (*modelsdev.Store, error) { return l.st, l.err } -func (l *lazyModelStore) GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) { +func (l *lazyModelStore) GetModel(ctx context.Context, id modelsdev.ID) (*modelsdev.Model, error) { st, err := l.load() if err != nil { return nil, err } - return st.GetModel(ctx, modelID) + return st.GetModel(ctx, id) } func (l *lazyModelStore) GetDatabase(ctx context.Context) (*modelsdev.Database, error) { diff --git a/pkg/runtime/lazy_model_store_test.go b/pkg/runtime/lazy_model_store_test.go index 4630f5c96..b6e07e1e3 100644 --- a/pkg/runtime/lazy_model_store_test.go +++ b/pkg/runtime/lazy_model_store_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/team" ) @@ -69,7 +70,7 @@ func TestLazyModelStore_DefersError(t *testing.T) { l.err = wantErr }) - _, err := l.GetModel(t.Context(), "anything") + _, err := l.GetModel(t.Context(), modelsdev.NewID("openai", "anything")) require.ErrorIs(t, err, wantErr) _, err = l.GetDatabase(t.Context()) require.ErrorIs(t, err, wantErr) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 443d647af..a0a441970 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -343,7 +343,7 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, "model_override", ls.toolModelOverride, "error", err) } else { slog.InfoContext(ctx, "Using per-tool model override for this turn", - "agent", a.Name(), "override", overrideModel.ID(), "primary", model.ID()) + "agent", a.Name(), "override", overrideModel.ID().String(), "primary", model.ID().String()) model = overrideModel } ls.toolModelOverride = "" @@ -354,10 +354,10 @@ func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, // Notify sidebar of the model for this turn. For rule-based // routing, the actual routed model is emitted from within the // stream once the first chunk arrives. - sink.Emit(AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())) + sink.Emit(AgentInfo(a.Name(), modelID.String(), a.Description(), a.WelcomeMessage())) - slog.DebugContext(ctx, "Using agent", "agent", a.Name(), "model", modelID) - slog.DebugContext(ctx, "Getting model definition", "model_id", modelID) + slog.DebugContext(ctx, "Using agent", "agent", a.Name(), "model", modelID.String()) + slog.DebugContext(ctx, "Getting model definition", "model_id", modelID.String()) m, err := r.modelsStore.GetModel(ctx, modelID) if err != nil { slog.DebugContext(ctx, "Failed to get model definition", "error", err) @@ -452,7 +452,7 @@ func (r *LocalRuntime) runTurn( a *agent.Agent, m *modelsdev.Model, model provider.Provider, - modelID string, + modelID modelsdev.ID, contextLimit int64, sessionSpan trace.Span, agentTools []tools.Tool, @@ -519,7 +519,7 @@ func (r *LocalRuntime) runTurn( // runtime's Go-only message transforms so a hook that drops a // message (e.g. a custom "strip system reminders") doesn't get // silently overridden by a transform later in the chain. - stop, msg, rewritten := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID, ls.iteration, messages) + stop, msg, rewritten := r.executeBeforeLLMCallHooks(ctx, sess, a, modelID.String(), ls.iteration, messages) if stop { slog.WarnContext(ctx, "before_llm_call hook signalled run termination", "agent", a.Name(), "session_id", sess.ID, "reason", msg) @@ -540,7 +540,7 @@ func (r *LocalRuntime) runTurn( // passed explicitly so transforms see the actual model the // loop chose (per-tool override + alloy-mode selection), // not whatever a fresh agent.Model() call would re-randomize. - messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, modelID, messages) + messages = r.applyBeforeLLMCallTransforms(ctx, sess, a, modelID.String(), messages) // Try primary model with fallback chain if configured res, usedModel, err := r.fallback.execute(streamCtx, a, model, messages, agentTools, sess, m, events) @@ -564,8 +564,8 @@ func (r *LocalRuntime) runTurn( r.executeAfterLLMCallHooks(ctx, sess, a, res.Content) if usedModel != nil && usedModel.ID() != model.ID() { - slog.InfoContext(ctx, "Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID()) - events.Emit(AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())) + slog.InfoContext(ctx, "Used fallback model", "agent", a.Name(), "primary", model.ID().String(), "used", usedModel.ID().String()) + events.Emit(AgentInfo(a.Name(), usedModel.ID().String(), a.Description(), a.WelcomeMessage())) } streamSpan.SetAttributes( attribute.Int("tool.calls", len(res.Calls)), @@ -575,7 +575,7 @@ func (r *LocalRuntime) runTurn( endStreamSpan() slog.DebugContext(ctx, "Stream processed", "agent", a.Name(), "tool_calls", len(res.Calls), "content_length", len(res.Content), "stopped", res.Stopped) - msgUsage := r.recordAssistantMessage(sess, a, res, agentTools, modelID, m, events) + msgUsage := r.recordAssistantMessage(sess, a, res, agentTools, modelID.String(), m, events) usage := SessionUsage(sess, contextLimit) usage.LastMessage = msgUsage diff --git a/pkg/runtime/loop_steps_test.go b/pkg/runtime/loop_steps_test.go index 216a381ea..8c60a2c25 100644 --- a/pkg/runtime/loop_steps_test.go +++ b/pkg/runtime/loop_steps_test.go @@ -182,7 +182,7 @@ func TestHandleStreamError_ContextCanceled_Fatal(t *testing.T) { // trying to drive a real compaction agent. type errModelStore struct{ ModelStore } -func (errModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { +func (errModelStore) GetModel(_ context.Context, _ modelsdev.ID) (*modelsdev.Model, error) { return nil, errors.New("no model") } diff --git a/pkg/runtime/model_picker.go b/pkg/runtime/model_picker.go index f788a11a7..fb625643e 100644 --- a/pkg/runtime/model_picker.go +++ b/pkg/runtime/model_picker.go @@ -71,7 +71,7 @@ func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, } if a, err := r.team.Agent(currentName); err == nil { - events.Emit(AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())) + events.Emit(AgentInfo(a.Name(), r.getEffectiveModelID(a).String(), a.Description(), a.WelcomeMessage())) } else { slog.WarnContext(ctx, "Failed to retrieve agent after model change; UI may not reflect the update", "agent", currentName, "error", err) } diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index d488b6e9e..50231dee4 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -134,7 +134,7 @@ func (r *LocalRuntime) setAgentModelInternal(ctx context.Context, agentName, mod return agent.ModelOverrideSnapshot{}, fmt.Errorf("failed to create model from config: %w", err) } snap := a.SetModelOverride(prov) - slog.InfoContext(ctx, "Set agent model override", "agent", agentName, "model", prov.ID(), "config_name", modelRef) + slog.InfoContext(ctx, "Set agent model override", "agent", agentName, "model", prov.ID().String(), "config_name", modelRef) return snap, nil } @@ -156,7 +156,7 @@ func (r *LocalRuntime) setAgentModelInternal(ctx context.Context, agentName, mod return agent.ModelOverrideSnapshot{}, fmt.Errorf("failed to resolve model %q: %w", modelRef, err) } snap := a.SetModelOverride(prov) - slog.InfoContext(ctx, "Set agent model override (inline)", "agent", agentName, "model", prov.ID()) + slog.InfoContext(ctx, "Set agent model override (inline)", "agent", agentName, "model", prov.ID().String()) return snap, nil } @@ -405,7 +405,7 @@ func (r *LocalRuntime) populateCatalogMetadata(ctx context.Context, choice *Mode if r.modelsStore == nil { return } - m, err := r.modelsStore.GetModel(ctx, providerID+"/"+modelID) + m, err := r.modelsStore.GetModel(ctx, modelsdev.NewID(providerID, modelID)) if err == nil { applyCatalogMetadata(choice, m) } @@ -543,7 +543,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest if cfg.MaxTokens != nil { opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens)) } else if r.modelsStore != nil { - m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model) + m, err := r.modelsStore.GetModel(ctx, modelsdev.NewID(cfg.Provider, cfg.Model)) if err == nil && m != nil { opts = append(opts, options.WithMaxTokens(m.Limit.Output)) } diff --git a/pkg/runtime/on_agent_switch_test.go b/pkg/runtime/on_agent_switch_test.go index d9e8c0ca7..8c1d436ef 100644 --- a/pkg/runtime/on_agent_switch_test.go +++ b/pkg/runtime/on_agent_switch_test.go @@ -13,6 +13,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/hooks" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" ) @@ -128,7 +129,7 @@ type endpointProvider struct { cfg base.Config } -func (p *endpointProvider) ID() string { return p.cfg.ID() } +func (p *endpointProvider) ID() modelsdev.ID { return p.cfg.ID() } func (p *endpointProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { return &mockStream{}, nil diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index d60c91762..7e9fccdd8 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -159,7 +159,7 @@ type CurrentAgentInfo struct { } type ModelStore interface { - GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) + GetModel(ctx context.Context, id modelsdev.ID) (*modelsdev.Model, error) GetDatabase(ctx context.Context) (*modelsdev.Database, error) } @@ -852,18 +852,19 @@ func (r *LocalRuntime) TitleGenerator() *sessiontitle.Generator { return sessiontitle.New(model, a.FallbackModels()...) } -// getAgentModelID returns the model ID for an agent, or empty string if no model is set. -func getAgentModelID(a *agent.Agent) string { +// getAgentModelID returns the model ID for an agent. The zero ID is +// returned when no model is configured. +func getAgentModelID(a *agent.Agent) modelsdev.ID { if model := a.Model(context.TODO()); model != nil { return model.ID() } - return "" + return modelsdev.ID{} } // getEffectiveModelID returns the currently active model ID for an agent, accounting // for any active fallback cooldown. During a cooldown period, this returns the fallback // model ID instead of the configured primary model, so the UI reflects the actual model in use. -func (r *LocalRuntime) getEffectiveModelID(a *agent.Agent) string { +func (r *LocalRuntime) getEffectiveModelID(a *agent.Agent) modelsdev.ID { cooldownState := r.fallback.cooldowns.Get(a.Name()) if cooldownState != nil { fallbacks := a.FallbackModels() @@ -891,15 +892,9 @@ func (r *LocalRuntime) agentDetailsFromTeam() []AgentDetails { if a, err := r.team.Agent(info.Name); err == nil && a != nil { fallbacks := a.FallbackModels() if cooldownState.fallbackIndex >= 0 && cooldownState.fallbackIndex < len(fallbacks) { - fb := fallbacks[cooldownState.fallbackIndex] - // Parse provider/model from the fallback model ID - modelID := fb.ID() - if p, m, found := strings.Cut(modelID, "/"); found { - providerName = p - modelName = m - } else { - modelName = modelID - } + fb := fallbacks[cooldownState.fallbackIndex].ID() + providerName = fb.Provider + modelName = fb.Model } } } @@ -1018,7 +1013,7 @@ func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, sess *session.Sessio // Emit agent and team information immediately for fast sidebar display // Use getEffectiveModelID to account for active fallback cooldowns modelID := r.getEffectiveModelID(a) - if !send(AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())) { + if !send(AgentInfo(a.Name(), modelID.String(), a.Description(), a.WelcomeMessage())) { return } if !send(TeamInfo(r.agentDetailsFromTeam(), r.CurrentAgentName())) { diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 03d0f5649..d67459585 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -159,7 +159,7 @@ type mockProvider struct { stream chat.MessageStream } -func (m *mockProvider) ID() string { return m.id } +func (m *mockProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(m.id) } func (m *mockProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { return m.stream, nil @@ -173,7 +173,7 @@ type mockProviderWithError struct { id string } -func (m *mockProviderWithError) ID() string { return m.id } +func (m *mockProviderWithError) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(m.id) } func (m *mockProviderWithError) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { return nil, errors.New("simulated error creating chat completion stream") @@ -187,7 +187,7 @@ type mockModelStore struct { ModelStore } -func (m mockModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { +func (m mockModelStore) GetModel(_ context.Context, _ modelsdev.ID) (*modelsdev.Model, error) { return nil, nil } @@ -645,7 +645,7 @@ type queueProvider struct { streams []chat.MessageStream } -func (p *queueProvider) ID() string { return p.id } +func (p *queueProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(p.id) } func (p *queueProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { p.mu.Lock() @@ -668,7 +668,7 @@ type mockModelStoreWithLimit struct { limit int } -func (m mockModelStoreWithLimit) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { +func (m mockModelStoreWithLimit) GetModel(_ context.Context, _ modelsdev.ID) (*modelsdev.Model, error) { return &modelsdev.Model{Limit: modelsdev.Limit{Context: m.limit}, Cost: &modelsdev.Cost{}}, nil } @@ -724,7 +724,7 @@ type errorProvider struct { err error } -func (p *errorProvider) ID() string { return p.id } +func (p *errorProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(p.id) } func (p *errorProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { return nil, p.err @@ -2526,7 +2526,7 @@ type recordingProvider struct { recordedCalls [][]tools.Tool // tools passed on each call } -func (r *recordingProvider) ID() string { return r.id } +func (r *recordingProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(r.id) } func (r *recordingProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, toolList []tools.Tool) (chat.MessageStream, error) { r.mu.Lock() @@ -2732,7 +2732,7 @@ type messageRecordingProvider struct { recordedMessages [][]chat.Message // messages passed on each call } -func (p *messageRecordingProvider) ID() string { return p.id } +func (p *messageRecordingProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(p.id) } func (p *messageRecordingProvider) CreateChatCompletionStream(_ context.Context, msgs []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { p.mu.Lock() @@ -2988,7 +2988,7 @@ type steerInjectProvider struct { mu sync.Mutex } -func (p *steerInjectProvider) ID() string { return p.id } +func (p *steerInjectProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(p.id) } func (p *steerInjectProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { p.mu.Lock() diff --git a/pkg/runtime/strip_modalities.go b/pkg/runtime/strip_modalities.go index 13456804f..404100f8f 100644 --- a/pkg/runtime/strip_modalities.go +++ b/pkg/runtime/strip_modalities.go @@ -7,6 +7,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/modelsdev" ) // BuiltinStripUnsupportedModalities is the name of the runtime-shipped @@ -54,7 +55,13 @@ func (r *LocalRuntime) stripUnsupportedModalitiesTransform( slog.DebugContext(ctx, "strip_unsupported_modalities: skipping, no ModelID on input") return msgs, nil } - m, err := r.modelsStore.GetModel(ctx, in.ModelID) + id, err := modelsdev.ParseID(in.ModelID) + if err != nil { + slog.DebugContext(ctx, "strip_unsupported_modalities: skipping, invalid ModelID", + "model_id", in.ModelID, "error", err) + return msgs, nil + } + m, err := r.modelsStore.GetModel(ctx, id) if err != nil || m == nil { // Unknown model: keep the previous (inline) behavior of // passing messages through untouched. The model call will diff --git a/pkg/runtime/transforms_test.go b/pkg/runtime/transforms_test.go index 0b6076f1b..9b7af67c6 100644 --- a/pkg/runtime/transforms_test.go +++ b/pkg/runtime/transforms_test.go @@ -29,7 +29,7 @@ type modalityModelStore struct { err error } -func (m modalityModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { +func (m modalityModelStore) GetModel(_ context.Context, _ modelsdev.ID) (*modelsdev.Model, error) { return m.model, m.err } @@ -43,8 +43,8 @@ type modalityByIDStore struct { models map[string]*modelsdev.Model } -func (m modalityByIDStore) GetModel(_ context.Context, id string) (*modelsdev.Model, error) { - return m.models[id], nil +func (m modalityByIDStore) GetModel(_ context.Context, id modelsdev.ID) (*modelsdev.Model, error) { + return m.models[id.String()], nil } // recordingMsgProvider captures the messages each model call sees so diff --git a/pkg/runtime/turn_end_test.go b/pkg/runtime/turn_end_test.go index c3e798e06..7644d5cfc 100644 --- a/pkg/runtime/turn_end_test.go +++ b/pkg/runtime/turn_end_test.go @@ -13,6 +13,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/hooks" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" @@ -207,7 +208,7 @@ type blockingProvider struct { id string } -func (p *blockingProvider) ID() string { return p.id } +func (p *blockingProvider) ID() modelsdev.ID { return modelsdev.ParseIDOrZero(p.id) } func (p *blockingProvider) CreateChatCompletionStream(ctx context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) { // Snapshot ctx.Done() at stream-construction time — the runtime diff --git a/pkg/runtime/with_agent_model_test.go b/pkg/runtime/with_agent_model_test.go index ac4a97f34..1282a9333 100644 --- a/pkg/runtime/with_agent_model_test.go +++ b/pkg/runtime/with_agent_model_test.go @@ -70,7 +70,7 @@ func TestWithAgentModel(t *testing.T) { userPick := &mockProvider{id: "user/pick"} root := agent.New("root", "test", agent.WithModel(&mockProvider{id: "default/model"})) root.SetModelOverride(userPick) - require.Equal(t, "user/pick", root.Model(t.Context()).ID()) + require.Equal(t, "user/pick", root.Model(t.Context()).ID().String()) tm := team.New(team.WithAgents(root)) r := &LocalRuntime{ @@ -86,12 +86,12 @@ func TestWithAgentModel(t *testing.T) { // Inside the scope: override is cleared. assert.False(t, root.HasModelOverride()) - assert.Equal(t, "default/model", root.Model(t.Context()).ID()) + assert.Equal(t, "default/model", root.Model(t.Context()).ID().String()) // After restore: user's pick is back. restore() assert.True(t, root.HasModelOverride()) - assert.Equal(t, "user/pick", root.Model(t.Context()).ID()) + assert.Equal(t, "user/pick", root.Model(t.Context()).ID().String()) }) t.Run("restore is idempotent", func(t *testing.T) { @@ -110,10 +110,10 @@ func TestWithAgentModel(t *testing.T) { require.NoError(t, err) restore() - assert.Equal(t, "user/pick", root.Model(t.Context()).ID()) + assert.Equal(t, "user/pick", root.Model(t.Context()).ID().String()) // Second call is a CAS no-op (the state is already restored). assert.NotPanics(t, restore) - assert.Equal(t, "user/pick", root.Model(t.Context()).ID()) + assert.Equal(t, "user/pick", root.Model(t.Context()).ID().String()) }) t.Run("concurrent change is preserved by restore", func(t *testing.T) { @@ -140,6 +140,6 @@ func TestWithAgentModel(t *testing.T) { // Restore must be a no-op because the override changed. restore() require.True(t, root.HasModelOverride(), "concurrent change must be preserved") - assert.Equal(t, "user/pick", root.Model(t.Context()).ID()) + assert.Equal(t, "user/pick", root.Model(t.Context()).ID().String()) }) } diff --git a/pkg/sessiontitle/generator_test.go b/pkg/sessiontitle/generator_test.go index ad64f8577..300dc8bce 100644 --- a/pkg/sessiontitle/generator_test.go +++ b/pkg/sessiontitle/generator_test.go @@ -11,17 +11,18 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/tools" ) type mockProvider struct { - id string + id modelsdev.ID calls int createFn func() (chat.MessageStream, error) baseCfgFn func() base.Config } -func (p *mockProvider) ID() string { return p.id } +func (p *mockProvider) ID() modelsdev.ID { return p.id } func (p *mockProvider) CreateChatCompletionStream( _ context.Context, @@ -78,13 +79,13 @@ func TestGenerator_Generate_FallsBackOnStreamCreateError(t *testing.T) { t.Parallel() primary := &mockProvider{ - id: "primary/fail", + id: modelsdev.NewID("primary", "fail"), createFn: func() (chat.MessageStream, error) { return nil, errors.New("primary boom") }, } fallback := &mockProvider{ - id: "fallback/success", + id: modelsdev.NewID("fallback", "success"), createFn: func() (chat.MessageStream, error) { return streamWithContent("My Title"), nil }, @@ -114,13 +115,13 @@ func TestGenerator_Generate_FallsBackOnRecvError(t *testing.T) { } primary := &mockProvider{ - id: "primary/recv-error", + id: modelsdev.NewID("primary", "recv-error"), createFn: func() (chat.MessageStream, error) { return primaryStream, nil }, } fallback := &mockProvider{ - id: "fallback/success", + id: modelsdev.NewID("fallback", "success"), createFn: func() (chat.MessageStream, error) { return streamWithContent("Recovered Title"), nil }, @@ -138,13 +139,13 @@ func TestGenerator_Generate_FallsBackOnEmptyOutput(t *testing.T) { t.Parallel() primary := &mockProvider{ - id: "primary/empty", + id: modelsdev.NewID("primary", "empty"), createFn: func() (chat.MessageStream, error) { return streamWithContent("\n\n"), nil }, } fallback := &mockProvider{ - id: "fallback/success", + id: modelsdev.NewID("fallback", "success"), createFn: func() (chat.MessageStream, error) { return streamWithContent("Good Title"), nil }, diff --git a/pkg/team/team.go b/pkg/team/team.go index 5b6cb6efd..dc1ad87a0 100644 --- a/pkg/team/team.go +++ b/pkg/team/team.go @@ -65,13 +65,9 @@ func (t *Team) AgentsInfo() []AgentInfo { Commands: a.Commands(), } if model := a.Model(context.TODO()); model != nil { - modelID := model.ID() - if prov, modelName, found := strings.Cut(modelID, "/"); found { - info.Provider = prov - info.Model = modelName - } else { - info.Model = modelID - } + id := model.ID() + info.Provider = id.Provider + info.Model = id.Model } infos = append(infos, info) } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index acddc7791..9d4d5029c 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -20,6 +20,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/dmr" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/permissions" "github.com/docker/docker-agent/pkg/skills" "github.com/docker/docker-agent/pkg/team" @@ -306,7 +307,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC if modelCfg.MaxTokens != nil { maxTokens = modelCfg.MaxTokens } else if modelsStoreErr == nil { - m, err := modelsStore.GetModel(ctx, modelCfg.Provider+"/"+modelCfg.Model) + m, err := modelsStore.GetModel(ctx, modelsdev.NewID(modelCfg.Provider, modelCfg.Model)) if err == nil { maxTokens = &m.Limit.Output } @@ -369,7 +370,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates if modelCfg.MaxTokens != nil { maxTokens = modelCfg.MaxTokens } else if modelsStoreErr == nil { - m, err := modelsStore.GetModel(ctx, modelCfg.Provider+"/"+modelCfg.Model) + m, err := modelsStore.GetModel(ctx, modelsdev.NewID(modelCfg.Provider, modelCfg.Model)) if err == nil { maxTokens = &m.Limit.Output } diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index 760fc18d0..60dca78f8 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -196,7 +196,7 @@ func TestOverrideModel(t *testing.T) { require.NoError(t, err) rootAgent, err := team.Agent("root") require.NoError(t, err) - require.Equal(t, test.expected, rootAgent.Model(t.Context()).ID()) + require.Equal(t, test.expected, rootAgent.Model(t.Context()).ID().String()) } }) }