diff --git a/cmd/wasm/runtime_wasm.go b/cmd/wasm/runtime_wasm.go index 297f02a39..371fc2fd2 100644 --- a/cmd/wasm/runtime_wasm.go +++ b/cmd/wasm/runtime_wasm.go @@ -125,7 +125,7 @@ func buildRuntime(ctx context.Context, cfg *latest.Config, env environment.Provi case ts.Type == "filesystem" || ts.Type == "": // Default toolset type is filesystem. if ts.Type == "filesystem" { - fsTool := filesystem.NewFilesystemTool("/") + fsTool := filesystem.New("/") opts = append(opts, agent.WithToolSets(fsTool)) } case ts.Remote.URL != "": diff --git a/examples/golibrary/builtintool/main.go b/examples/golibrary/builtintool/main.go index 912aca6e3..ddbdea1f2 100644 --- a/examples/golibrary/builtintool/main.go +++ b/examples/golibrary/builtintool/main.go @@ -47,7 +47,7 @@ func run(ctx context.Context) error { "root", "You are an expert hacker", agent.WithModel(llm), - agent.WithToolSets(shell.NewShellTool(os.Environ(), &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})), + agent.WithToolSets(shell.New(os.Environ(), &config.RuntimeConfig{Config: config.Config{WorkingDir: "/tmp"}})), ), ), ) diff --git a/examples/golibrary/multi/main.go b/examples/golibrary/multi/main.go index f493d9e57..74905c4e4 100644 --- a/examples/golibrary/multi/main.go +++ b/examples/golibrary/multi/main.go @@ -49,7 +49,7 @@ func run(ctx context.Context) error { "You are a human, with feelings and emotions.", agent.WithModel(llm), agent.WithSubAgents(child), - agent.WithToolSets(transfertask.NewTransferTaskTool()), + agent.WithToolSets(transfertask.New()), ) rt, err := runtime.New(team.New(team.WithAgents(root, child))) if err != nil { diff --git a/pkg/acp/filesystem.go b/pkg/acp/filesystem.go index efdb47366..fb33fef85 100644 --- a/pkg/acp/filesystem.go +++ b/pkg/acp/filesystem.go @@ -32,7 +32,7 @@ func getSessionID(ctx context.Context) (string, bool) { // FilesystemToolset wraps a standard Tool and overrides read_file, write_file, // and edit_file to use the ACP connection for file operations type FilesystemToolset struct { - *filesystem.Tool + *filesystem.ToolSet agent *Agent workingDir string @@ -43,7 +43,7 @@ var _ tools.ToolSet = (*FilesystemToolset)(nil) // NewFilesystemToolset creates a new ACP-specific filesystem toolset func NewFilesystemToolset(agent *Agent, workingDir string, opts ...filesystem.Opt) *FilesystemToolset { return &FilesystemToolset{ - Tool: filesystem.NewFilesystemTool(workingDir, opts...), + ToolSet: filesystem.New(workingDir, opts...), agent: agent, workingDir: workingDir, } @@ -51,7 +51,7 @@ func NewFilesystemToolset(agent *Agent, workingDir string, opts ...filesystem.Op // Tools returns the tool definitions with ACP-specific overrides func (t *FilesystemToolset) Tools(ctx context.Context) ([]tools.Tool, error) { - baseTools, err := t.Tool.Tools(ctx) + baseTools, err := t.ToolSet.Tools(ctx) if err != nil { return nil, err } diff --git a/pkg/acp/registry.go b/pkg/acp/registry.go index fb9c7af2c..99b532aac 100644 --- a/pkg/acp/registry.go +++ b/pkg/acp/registry.go @@ -11,10 +11,20 @@ import ( ) // createToolsetRegistry creates a custom toolset registry with ACP-specific filesystem toolset -func createToolsetRegistry(agent *Agent) *teamloader.ToolsetRegistry { - registry := teamloader.NewDefaultToolsetRegistry() +func createToolsetRegistry(agent *Agent) teamloader.ToolsetRegistry { + return &acpToolsetRegistry{ + agent: agent, + registry: teamloader.NewDefaultToolsetRegistry(), + } +} + +type acpToolsetRegistry struct { + agent *Agent + registry teamloader.ToolsetRegistry +} - registry.Register("filesystem", func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { +func (r *acpToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, agentName string) (tools.ToolSet, error) { + if toolset.Type == "filesystem" { wd := runConfig.WorkingDir if wd == "" { var err error @@ -24,8 +34,8 @@ func createToolsetRegistry(agent *Agent) *teamloader.ToolsetRegistry { } } - return NewFilesystemToolset(agent, wd), nil - }) + return NewFilesystemToolset(r.agent, wd), nil + } - return registry + return r.registry.CreateTool(ctx, toolset, parentDir, runConfig, agentName) } diff --git a/pkg/app/app_test.go b/pkg/app/app_test.go index 7f2304c5f..cb8fbcf29 100644 --- a/pkg/app/app_test.go +++ b/pkg/app/app_test.go @@ -60,7 +60,7 @@ func (m *mockRuntime) SessionStore() session.Store { return m.store } func (m *mockRuntime) Summarize(ctx context.Context, sess *session.Session, additionalPrompt string, events runtime.EventSink) { } func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil } -func (m *mockRuntime) CurrentAgentSkillsToolset() *skillstool.Toolset { +func (m *mockRuntime) CurrentAgentSkillsToolset() *skillstool.ToolSet { return nil } diff --git a/pkg/cli/runner_test.go b/pkg/cli/runner_test.go index 4c14f7b63..e2dba1ece 100644 --- a/pkg/cli/runner_test.go +++ b/pkg/cli/runner_test.go @@ -51,7 +51,7 @@ func (m *mockRuntime) ResumeElicitation(_ context.Context, action tools.Elicitat func (m *mockRuntime) SessionStore() session.Store { return nil } func (m *mockRuntime) Summarize(context.Context, *session.Session, string, runtime.EventSink) {} func (m *mockRuntime) PermissionsInfo() *runtime.PermissionsInfo { return nil } -func (m *mockRuntime) CurrentAgentSkillsToolset() *skillstool.Toolset { return nil } +func (m *mockRuntime) CurrentAgentSkillsToolset() *skillstool.ToolSet { return nil } func (m *mockRuntime) CurrentMCPPrompts(context.Context) map[string]mcptools.PromptInfo { return nil } diff --git a/pkg/runtime/commands_test.go b/pkg/runtime/commands_test.go index 2a42409dd..db0f534b0 100644 --- a/pkg/runtime/commands_test.go +++ b/pkg/runtime/commands_test.go @@ -54,7 +54,7 @@ func (m *mockRuntime) SessionStore() session.Store { return nil } func (m *mockRuntime) Summarize(context.Context, *session.Session, string, EventSink) { } func (m *mockRuntime) PermissionsInfo() *PermissionsInfo { return nil } -func (m *mockRuntime) CurrentAgentSkillsToolset() *skillstool.Toolset { +func (m *mockRuntime) CurrentAgentSkillsToolset() *skillstool.ToolSet { return nil } diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index a0a441970..81201650c 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -841,7 +841,7 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events EventSink // channel; a blocking send after the channel is closed would // crash, and a blocking send when the consumer has gone away // would deadlock. - if ragTool, ok := tools.As[*builtinrag.Tool](toolset); ok { + if ragTool, ok := tools.As[*builtinrag.ToolSet](toolset); ok { ragTool.SetEventCallback(ragEventForwarder(ragTool.Name(), r, nonBlocking(events).Emit)) } } diff --git a/pkg/runtime/model_picker.go b/pkg/runtime/model_picker.go index fb625643e..c86204c2a 100644 --- a/pkg/runtime/model_picker.go +++ b/pkg/runtime/model_picker.go @@ -15,14 +15,14 @@ import ( // findModelPickerTool returns the Tool from the current agent's // toolsets, or nil if the agent has no model_picker configured. -func (r *LocalRuntime) findModelPickerTool() *modelpicker.Tool { +func (r *LocalRuntime) findModelPickerTool() *modelpicker.ToolSet { currentName := r.CurrentAgentName() a, err := r.team.Agent(currentName) if err != nil { return nil } for _, ts := range a.ToolSets() { - if mpt, ok := tools.As[*modelpicker.Tool](ts); ok { + if mpt, ok := tools.As[*modelpicker.ToolSet](ts); ok { return mpt } } diff --git a/pkg/runtime/remote_runtime.go b/pkg/runtime/remote_runtime.go index dcbe1dea8..cdd14495e 100644 --- a/pkg/runtime/remote_runtime.go +++ b/pkg/runtime/remote_runtime.go @@ -567,7 +567,7 @@ func (r *RemoteRuntime) ResetStartupInfo() { } // CurrentAgentSkillsToolset returns nil for remote runtimes since skills are managed server-side. -func (r *RemoteRuntime) CurrentAgentSkillsToolset() *skills.Toolset { +func (r *RemoteRuntime) CurrentAgentSkillsToolset() *skills.ToolSet { return nil } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 7e9fccdd8..eec5b8e2d 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -83,7 +83,7 @@ type Runtime interface { PermissionsInfo() *PermissionsInfo // CurrentAgentSkillsToolset returns the skills toolset for the current agent, or nil if skills are not enabled. - CurrentAgentSkillsToolset() *skills.Toolset + CurrentAgentSkillsToolset() *skills.ToolSet // CurrentMCPPrompts returns MCP prompts available from the current agent's toolsets. // Returns an empty map if no MCP prompts are available. @@ -779,13 +779,13 @@ func (r *LocalRuntime) resolveSessionAgent(sess *session.Session) *agent.Agent { } // CurrentAgentSkillsToolset returns the skills toolset for the current agent, or nil if not enabled. -func (r *LocalRuntime) CurrentAgentSkillsToolset() *skills.Toolset { +func (r *LocalRuntime) CurrentAgentSkillsToolset() *skills.ToolSet { a := r.CurrentAgent() if a == nil { return nil } for _, ts := range a.ToolSets() { - if st, ok := tools.As[*skills.Toolset](ts); ok { + if st, ok := tools.As[*skills.ToolSet](ts); ok { return st } } diff --git a/pkg/runtime/skill_runner.go b/pkg/runtime/skill_runner.go index 5f5fbd2fa..dede6e7ea 100644 --- a/pkg/runtime/skill_runner.go +++ b/pkg/runtime/skill_runner.go @@ -21,7 +21,7 @@ import ( // its final response is returned as the tool result. // // All skill-specific business rules (lookup, fork-mode validation, content -// expansion) live in (*skills.Toolset).PrepareForkSubSession; this +// expansion) live in (*skills.ToolSet).PrepareForkSubSession; this // handler keeps only the runtime-private orchestration that runForwarding // can't generalise — namely the optional model override that applies for // the sub-session's lifetime. diff --git a/pkg/session/session_test.go b/pkg/session/session_test.go index 8d3131043..66612597e 100644 --- a/pkg/session/session_test.go +++ b/pkg/session/session_test.go @@ -10,10 +10,18 @@ import ( "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/builtin/todo" ) +func todoToolSet(t *testing.T) tools.ToolSet { + t.Helper() + toolSet, err := todo.CreateToolSet(latest.Toolset{}) + require.NoError(t, err) + return toolSet +} + func TestTrimMessagesWithToolCalls(t *testing.T) { messages := []chat.Message{ { @@ -173,7 +181,7 @@ func TestGetMessages_Instructions(t *testing.T) { } func TestGetMessages_CacheControl(t *testing.T) { - testAgent := agent.New("root", "instructions", agent.WithToolSets(&todo.Tool{})) + testAgent := agent.New("root", "instructions", agent.WithToolSets(todoToolSet(t))) s := New() messages := s.GetMessages(testAgent) @@ -197,7 +205,7 @@ func TestGetMessages_CacheControlWithSummary(t *testing.T) { // buildContextSpecificSystemMessages caching behavior. // - Summary and conversation messages are not cache-controlled. testAgent := agent.New("root", "instructions", - agent.WithToolSets(&todo.Tool{}), + agent.WithToolSets(todoToolSet(t)), ) s := New() diff --git a/pkg/teamloader/lifecycle.go b/pkg/teamloader/lifecycle.go deleted file mode 100644 index 35ad45244..000000000 --- a/pkg/teamloader/lifecycle.go +++ /dev/null @@ -1,84 +0,0 @@ -package teamloader - -import ( - "log/slog" - - "github.com/docker/docker-agent/pkg/config/latest" - "github.com/docker/docker-agent/pkg/tools/lifecycle" -) - -// lifecyclePolicyFromConfig converts a latest.LifecycleConfig into a -// lifecycle.Policy. nil cfg returns the resilient default policy. -// -// Resolution order: profile defaults first, then explicit field overrides. -// The Logger field is always populated with a component-tagged slog so -// supervisor messages identify which toolset produced them. -func lifecyclePolicyFromConfig(name string, cfg *latest.LifecycleConfig) lifecycle.Policy { - policy := profilePolicy(profileName(cfg)) - policy.Logger = slog.With("component", "supervisor", "toolset", name) - - if cfg == nil { - return policy - } - if cfg.Restart != "" { - policy.Restart = parseRestart(cfg.Restart) - } - if cfg.MaxRestarts != 0 { - // 0 keeps the profile default; -1 means unlimited (in both this - // config and the supervisor). - policy.MaxAttempts = cfg.MaxRestarts - } - if b := cfg.Backoff; b != nil { - if b.Initial.Duration > 0 { - policy.Backoff.Initial = b.Initial.Duration - } - if b.Max.Duration > 0 { - policy.Backoff.Max = b.Max.Duration - } - if b.Multiplier > 0 { - policy.Backoff.Multiplier = b.Multiplier - } - if b.Jitter > 0 { - policy.Backoff.Jitter = b.Jitter - } - } - return policy -} - -// profileName returns the effective profile name, defaulting to -// "resilient" when cfg is nil or its Profile field is empty. -func profileName(cfg *latest.LifecycleConfig) string { - if cfg == nil || cfg.Profile == "" { - return latest.LifecycleProfileResilient - } - return cfg.Profile -} - -// profilePolicy returns the lifecycle.Policy defaults for a profile name. -// Strict and best-effort produce the same supervisor policy (no restart); -// they differ in the Required flag which is documented but not yet -// enforced by the runtime. -func profilePolicy(profile string) lifecycle.Policy { - switch profile { - case latest.LifecycleProfileStrict, latest.LifecycleProfileBestEffort: - // MaxAttempts is moot when Restart=Never; -1 keeps it explicit. - return lifecycle.Policy{Restart: lifecycle.RestartNever, MaxAttempts: -1} - default: // resilient (default + fallback for unknown names; the - // validator already rejects unknown names). - return lifecycle.Policy{Restart: lifecycle.RestartOnFailure, MaxAttempts: 5} - } -} - -// parseRestart converts a YAML restart string into the supervisor enum. -// Unknown values fall back to RestartOnFailure (the validator rejects -// them upstream). -func parseRestart(s string) lifecycle.Restart { - switch s { - case "never": - return lifecycle.RestartNever - case "always": - return lifecycle.RestartAlways - default: - return lifecycle.RestartOnFailure - } -} diff --git a/pkg/teamloader/lifecycle_test.go b/pkg/teamloader/lifecycle_test.go deleted file mode 100644 index ab2accf43..000000000 --- a/pkg/teamloader/lifecycle_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package teamloader - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/docker/docker-agent/pkg/config/latest" - "github.com/docker/docker-agent/pkg/tools/lifecycle" -) - -func TestLifecyclePolicyFromConfig_NilUsesResilientDefaults(t *testing.T) { - t.Parallel() - p := lifecyclePolicyFromConfig("test", nil) - assert.Equal(t, lifecycle.RestartOnFailure, p.Restart) - assert.Equal(t, 5, p.MaxAttempts) - assert.NotNil(t, p.Logger) -} - -func TestLifecyclePolicyFromConfig_StrictProfile(t *testing.T) { - t.Parallel() - p := lifecyclePolicyFromConfig("test", &latest.LifecycleConfig{ - Profile: latest.LifecycleProfileStrict, - }) - assert.Equal(t, lifecycle.RestartNever, p.Restart) - assert.Equal(t, -1, p.MaxAttempts) -} - -func TestLifecyclePolicyFromConfig_BestEffortProfile(t *testing.T) { - t.Parallel() - p := lifecyclePolicyFromConfig("test", &latest.LifecycleConfig{ - Profile: latest.LifecycleProfileBestEffort, - }) - assert.Equal(t, lifecycle.RestartNever, p.Restart) -} - -func TestLifecyclePolicyFromConfig_ExplicitOverrides(t *testing.T) { - t.Parallel() - cfg := &latest.LifecycleConfig{ - Profile: latest.LifecycleProfileResilient, - Restart: "always", - MaxRestarts: 12, - Backoff: &latest.BackoffConfig{ - Initial: latest.Duration{Duration: 500 * time.Millisecond}, - Max: latest.Duration{Duration: 10 * time.Second}, - Multiplier: 1.5, - Jitter: 0.3, - }, - } - p := lifecyclePolicyFromConfig("test", cfg) - assert.Equal(t, lifecycle.RestartAlways, p.Restart) - assert.Equal(t, 12, p.MaxAttempts) - assert.Equal(t, 500*time.Millisecond, p.Backoff.Initial) - assert.Equal(t, 10*time.Second, p.Backoff.Max) - assert.InDelta(t, 1.5, p.Backoff.Multiplier, 0.001) - assert.InDelta(t, 0.3, p.Backoff.Jitter, 0.001) -} - -func TestLifecyclePolicyFromConfig_PartialOverridesKeepProfileDefaults(t *testing.T) { - t.Parallel() - cfg := &latest.LifecycleConfig{ - Profile: latest.LifecycleProfileResilient, - MaxRestarts: 7, - } - p := lifecyclePolicyFromConfig("test", cfg) - assert.Equal(t, lifecycle.RestartOnFailure, p.Restart, "profile default preserved") - assert.Equal(t, 7, p.MaxAttempts, "explicit override applied") -} - -func TestParseRestart(t *testing.T) { - t.Parallel() - cases := map[string]lifecycle.Restart{ - "": lifecycle.RestartOnFailure, - "on_failure": lifecycle.RestartOnFailure, - "never": lifecycle.RestartNever, - "always": lifecycle.RestartAlways, - } - for in, want := range cases { - assert.Equal(t, want, parseRestart(in), "input=%q", in) - } -} diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 593d44612..4b8714afe 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -3,23 +3,10 @@ package teamloader import ( "cmp" "context" - "errors" "fmt" - "log/slog" - "os" - "path/filepath" - "time" "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/config/latest" - "github.com/docker/docker-agent/pkg/environment" - "github.com/docker/docker-agent/pkg/gateway" - "github.com/docker/docker-agent/pkg/js" - "github.com/docker/docker-agent/pkg/memory/database/sqlite" - "github.com/docker/docker-agent/pkg/path" - "github.com/docker/docker-agent/pkg/paths" - "github.com/docker/docker-agent/pkg/rag" - "github.com/docker/docker-agent/pkg/toolinstall" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/a2a" agenttool "github.com/docker/docker-agent/pkg/tools/builtin/agent" @@ -43,29 +30,74 @@ import ( // configName identifies the agent config file (e.g. "memory_agent" from "memory_agent.yaml"). type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, configName string) (tools.ToolSet, error) -// ToolsetRegistry manages the registration of toolset creators by type -type ToolsetRegistry struct { +// ToolsetRegistry manages the registration of toolset creators by type. +type ToolsetRegistry interface { + CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, agentName string) (tools.ToolSet, error) +} + +func NewDefaultToolsetRegistry() ToolsetRegistry { + return &toolsetRegistry{ + creators: map[string]ToolsetCreator{ + "todo": func(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return todo.CreateToolSet(toolset) + }, + "tasks": func(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return tasks.CreateToolSet(toolset, parentDir, runConfig) + }, + "memory": func(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, configName string) (tools.ToolSet, error) { + return memory.CreateToolSet(toolset, parentDir, runConfig, configName) + }, + "think": func(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return think.CreateToolSet() + }, + "shell": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return shell.CreateToolSet(ctx, toolset, runConfig) + }, + "script": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return shell.CreateScriptToolSet(ctx, toolset, runConfig) + }, + "filesystem": func(_ context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return filesystem.CreateToolSet(toolset, runConfig) + }, + "fetch": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return fetch.CreateToolSet(ctx, toolset, runConfig) + }, + "mcp": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return mcp.CreateToolSet(ctx, toolset, runConfig) + }, + "api": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return api.CreateToolSet(ctx, toolset, runConfig) + }, + "a2a": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return a2a.CreateToolSet(ctx, toolset, runConfig) + }, + "lsp": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return lsp.CreateToolSet(ctx, toolset, runConfig) + }, + "user_prompt": func(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return userprompt.CreateToolSet() + }, + "openapi": func(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return openapi.CreateToolSet(ctx, toolset, runConfig) + }, + "model_picker": func(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return modelpicker.CreateToolSet(toolset) + }, + "background_agents": func(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return agenttool.CreateToolSet() + }, + "rag": func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + return builtinrag.CreateToolSet(ctx, toolset, parentDir, runConfig) + }, + }, + } +} + +// toolsetRegistry manages the registration of toolset creators by type. +type toolsetRegistry struct { creators map[string]ToolsetCreator } -// NewToolsetRegistry creates a new empty toolset registry -func NewToolsetRegistry() *ToolsetRegistry { - return &ToolsetRegistry{ - creators: make(map[string]ToolsetCreator), - } -} - -// Register adds a new toolset creator for the given type -func (r *ToolsetRegistry) Register(toolsetType string, creator ToolsetCreator) { - r.creators[toolsetType] = creator -} - -// Get retrieves a toolset creator for the given type -func (r *ToolsetRegistry) Get(toolsetType string) (ToolsetCreator, bool) { - creator, ok := r.creators[toolsetType] - return creator, ok -} - // CreateTool creates a toolset using the registered creator for the given type. // // Every successful toolset is decorated with tools.WithName so status @@ -74,8 +106,8 @@ func (r *ToolsetRegistry) Get(toolsetType string) (ToolsetCreator, bool) { // already advertise a non-empty Name(): it only fills the gap left by // built-in toolsets that don't take a `name:` field in YAML, replacing // the previous fallback to fmt.Sprintf("%T", ts). -func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, agentName string) (tools.ToolSet, error) { - creator, ok := r.Get(toolset.Type) +func (r *toolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, agentName string) (tools.ToolSet, error) { + creator, ok := r.creators[toolset.Type] if !ok { return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type) } @@ -85,450 +117,3 @@ func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset } return tools.WithName(ts, cmp.Or(toolset.Name, toolset.Type)), nil } - -func NewDefaultToolsetRegistry() *ToolsetRegistry { - r := NewToolsetRegistry() - // Register all built-in toolset creators - r.Register("todo", createTodoTool) - r.Register("tasks", createTasksTool) - r.Register("memory", createMemoryTool) - r.Register("think", createThinkTool) - r.Register("shell", createShellTool) - r.Register("script", createScriptTool) - r.Register("filesystem", createFilesystemTool) - r.Register("fetch", createFetchTool) - r.Register("mcp", createMCPTool) - r.Register("api", createAPITool) - r.Register("a2a", createA2ATool) - r.Register("lsp", createLSPTool) - r.Register("user_prompt", createUserPromptTool) - r.Register("openapi", createOpenAPITool) - r.Register("model_picker", createModelPickerTool) - r.Register("background_agents", createBackgroundAgentsTool) - r.Register("rag", createRAGTool) - return r -} - -// checkDirExists returns an error if the given directory does not exist or is -// not a directory. toolsetType is used only in the error message. -func checkDirExists(dir, toolsetType string) error { - info, err := os.Stat(dir) - if err != nil { - if os.IsNotExist(err) { - return fmt.Errorf("working_dir %q for %s toolset does not exist", dir, toolsetType) - } - return fmt.Errorf("working_dir %q for %s toolset: %w", dir, toolsetType, err) - } - if !info.IsDir() { - return fmt.Errorf("working_dir %q for %s toolset is not a directory", dir, toolsetType) - } - return nil -} - -// resolveToolsetWorkingDir returns the effective working directory for a toolset process. -// -// Resolution rules: -// - If toolsetWorkingDir is empty, agentWorkingDir is returned unchanged. -// - Shell patterns (~ and ${VAR}/$VAR) are expanded before any further processing. -// - If the expanded path is absolute, it is returned as-is. -// - If the expanded path is relative and agentWorkingDir is non-empty, -// it is joined with agentWorkingDir and made absolute via filepath.Abs. -// - If the expanded path is relative and agentWorkingDir is empty, -// the relative path is returned unchanged (caller will inherit the process cwd). -// -// Note: unlike resolveToolsetPath, this helper does not enforce containment -// within the agent working directory. working_dir is treated like command/args — -// a trusted, operator-authored value where cross-tree references (e.g. a sibling -// module root in a monorepo) are intentional and must not be silently blocked. -func resolveToolsetWorkingDir(toolsetWorkingDir, agentWorkingDir string) string { - if toolsetWorkingDir == "" { - return agentWorkingDir - } - // Expand ~ and environment variables before path operations. - toolsetWorkingDir = path.ExpandPath(toolsetWorkingDir) - if filepath.IsAbs(toolsetWorkingDir) { - return toolsetWorkingDir - } - if agentWorkingDir != "" { - // filepath.Abs cleans the result and anchors the URI correctly - // (avoids file://./backend-style LSP root URIs when the agent dir - // is itself absolute, which is the normal case). - abs, err := filepath.Abs(filepath.Join(agentWorkingDir, toolsetWorkingDir)) - if err == nil { - return abs - } - // Fallback: return the joined path without Abs (should not happen in practice). - return filepath.Join(agentWorkingDir, toolsetWorkingDir) - } - // agentWorkingDir is empty and path is relative: return as-is. - // The child process will inherit the OS working directory. - return toolsetWorkingDir -} - -// resolveToolsetPath expands shell patterns (~, env vars) in the given path, -// then validates it relative to the working directory or parent directory. -func resolveToolsetPath(toolsetPath, parentDir string, runConfig *config.RuntimeConfig) (string, error) { - toolsetPath = path.ExpandPath(toolsetPath) - - var basePath string - if filepath.IsAbs(toolsetPath) { - basePath = "" - } else if wd := runConfig.WorkingDir; wd != "" { - basePath = wd - } else { - basePath = parentDir - } - - return path.ValidatePathInDirectory(toolsetPath, basePath) -} - -func createTodoTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - if toolset.Shared { - return todo.NewSharedTodoTool(), nil - } - return todo.NewTodoTool(), nil -} - -func createTasksTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - toolsetPath := toolset.Path - if toolsetPath == "" { - toolsetPath = "tasks.json" - } - - validatedPath, err := resolveToolsetPath(toolsetPath, parentDir, runConfig) - if err != nil { - return nil, fmt.Errorf("invalid tasks storage path: %w", err) - } - if err := os.MkdirAll(filepath.Dir(validatedPath), 0o700); err != nil { - return nil, fmt.Errorf("failed to create tasks storage directory: %w", err) - } - - return tasks.NewTasksTool(validatedPath), nil -} - -func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, configName string) (tools.ToolSet, error) { - var validatedMemoryPath string - - if toolset.Path != "" { - var err error - validatedMemoryPath, err = resolveToolsetPath(toolset.Path, parentDir, runConfig) - if err != nil { - return nil, fmt.Errorf("invalid memory database path: %w", err) - } - } else { - // Default: ~/.cagent/memory//memory.db - if configName == "" { - configName = "default" - } - validatedMemoryPath = filepath.Join(paths.GetDataDir(), "memory", configName, "memory.db") - } - - if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil { - return nil, fmt.Errorf("failed to create memory database directory: %w", err) - } - - db, err := sqlite.NewMemoryDatabase(validatedMemoryPath) - if err != nil { - return nil, fmt.Errorf("failed to create memory database: %w", err) - } - - return memory.NewMemoryToolWithPath(db, validatedMemoryPath), nil -} - -func createThinkTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - return think.NewThinkTool(), nil -} - -func createShellTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), runConfig.EnvProvider()) - if err != nil { - return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) - } - env = append(env, os.Environ()...) - - return shell.NewShellTool(env, runConfig), nil -} - -func createScriptTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - if len(toolset.Shell) == 0 { - return nil, errors.New("shell is required for script toolset") - } - - env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), runConfig.EnvProvider()) - if err != nil { - return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) - } - env = append(env, os.Environ()...) - return shell.NewScriptShellTool(toolset.Shell, env) -} - -func createFilesystemTool(_ context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - wd := runConfig.WorkingDir - if wd == "" { - var err error - wd, err = os.Getwd() - if err != nil { - return nil, fmt.Errorf("failed to get working directory: %w", err) - } - } - - var opts []filesystem.Opt - - // Handle ignore_vcs configuration (default to true) - ignoreVCS := true - if toolset.IgnoreVCS != nil { - ignoreVCS = *toolset.IgnoreVCS - } - opts = append(opts, filesystem.WithIgnoreVCS(ignoreVCS)) - - // Handle allow/deny lists for filesystem operations. - // An empty / nil list preserves the default behaviour (no restriction). - if len(toolset.AllowList) > 0 { - opts = append(opts, filesystem.WithAllowList(toolset.AllowList)) - } - if len(toolset.DenyList) > 0 { - opts = append(opts, filesystem.WithDenyList(toolset.DenyList)) - } - - // Handle post-edit commands - if len(toolset.PostEdit) > 0 { - postEditConfigs := make([]filesystem.PostEditConfig, len(toolset.PostEdit)) - for i, pe := range toolset.PostEdit { - postEditConfigs[i] = filesystem.PostEditConfig{ - Path: pe.Path, - Cmd: pe.Cmd, - } - } - opts = append(opts, filesystem.WithPostEditCommands(postEditConfigs)) - } - - return filesystem.NewFilesystemTool(wd, opts...), nil -} - -func createAPITool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - if toolset.APIConfig.Endpoint == "" { - return nil, errors.New("api tool requires an endpoint in api_config") - } - - expander := js.NewJsExpander(runConfig.EnvProvider()) - toolset.APIConfig.Endpoint = expander.Expand(ctx, toolset.APIConfig.Endpoint, nil) - toolset.APIConfig.Headers = expander.ExpandMap(ctx, toolset.APIConfig.Headers) - - return api.NewAPITool(toolset.APIConfig, expander), nil -} - -func createFetchTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - // Expand ${env.X} in headers so secrets (API tokens, ...) can come from - // the environment instead of being inlined in YAML — same behaviour as - // openapi/a2a/mcp.remote/api headers. ExpandMap and WithHeaders are both - // nil-safe, so no guard is needed when the user hasn't configured any. - expander := js.NewJsExpander(runConfig.EnvProvider()) - - var opts []fetch.ToolOption - if toolset.Timeout > 0 { - timeout := time.Duration(toolset.Timeout) * time.Second - opts = append(opts, fetch.WithTimeout(timeout)) - } - if len(toolset.AllowedDomains) > 0 { - opts = append(opts, fetch.WithAllowedDomains(toolset.AllowedDomains)) - } - if len(toolset.BlockedDomains) > 0 { - opts = append(opts, fetch.WithBlockedDomains(toolset.BlockedDomains)) - } - if toolset.AllowPrivateIPs { - opts = append(opts, fetch.WithAllowPrivateIPs(true)) - } - opts = append(opts, fetch.WithHeaders(expander.ExpandMap(ctx, toolset.Headers))) - return fetch.NewFetchTool(opts...), nil -} - -func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - envProvider := runConfig.EnvProvider() - - // Resolve the working directory once; used for all subprocess-based branches. - // Note: validation only rejects working_dir for toolsets with an explicit - // remote.url. Ref-based MCPs (e.g. ref: docker:context7) pass validation - // regardless, because their transport type is only known at runtime via the - // MCP Catalog API. If such a ref resolves to a remote server at runtime, we - // return an explicit error below rather than silently discarding the field. - cwd := resolveToolsetWorkingDir(toolset.WorkingDir, runConfig.WorkingDir) - - // S1: validate the resolved directory exists (if one was specified) so we - // surface a clear error now rather than a cryptic exec failure later. - // Skip this check for ref-based toolsets whose transport type is not yet - // known — the check would be premature and potentially wrong. - if toolset.WorkingDir != "" && toolset.Ref == "" { - if err := checkDirExists(cwd, "mcp"); err != nil { - return nil, err - } - } - - switch { - // MCP Server from the MCP Catalog, running with the MCP Gateway - case toolset.Ref != "": - mcpServerName := gateway.ParseServerRef(toolset.Ref) - serverSpec, err := gateway.ServerSpec(ctx, mcpServerName) - if err != nil { - return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err) - } - - // TODO(dga): until the MCP Gateway supports oauth with docker agent, we fetch the remote url and directly connect to it. - if serverSpec.Type == "remote" { - // working_dir cannot be validated at config-parse time for ref-based - // MCPs because their transport type is only known here. Return a clear - // error rather than silently discarding the field. - if toolset.WorkingDir != "" { - return nil, fmt.Errorf("working_dir is not supported for MCP toolset %q: ref %q resolves to a remote server (no local subprocess)", - toolset.Name, toolset.Ref) - } - return mcp.NewRemoteToolset(toolset.Name, serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, nil, lifecyclePolicyFromConfig(toolset.Name, toolset.Lifecycle)), nil - } - - // The ref resolves to a local subprocess — validate the working directory now. - if toolset.WorkingDir != "" { - if err := checkDirExists(cwd, "mcp"); err != nil { - return nil, err - } - } - - env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) - if err != nil { - return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) - } - - envProvider := environment.NewMultiProvider( - environment.NewEnvListProvider(env), - envProvider, - ) - - // Pass the resolved cwd so gateway-based MCPs also honour working_dir. - return mcp.NewGatewayToolset(ctx, toolset.Name, mcpServerName, serverSpec.Secrets, toolset.Config, envProvider, cwd) - - // STDIO MCP Server from shell command - case toolset.Command != "": - // Auto-install missing command binary if needed. - // If EnsureCommand fails (binary not on PATH, no aqua package, etc.), - // treat as transient: create the toolset with the original command - // and let mcp.Toolset.Start() retry on each conversation turn. - resolvedCommand, err := toolinstall.EnsureCommand(ctx, toolset.Command, toolset.Version) - if err != nil { - slog.WarnContext(ctx, "MCP command not yet available, will retry on next turn", - "command", toolset.Command, "error", err) - resolvedCommand = toolset.Command - } - - env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) - if err != nil { - return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) - } - env = append(env, os.Environ()...) - - // Prepend tools bin dir to PATH so child processes can find installed tools - env = toolinstall.PrependBinDirToEnv(env) - - return mcp.NewToolsetCommand(toolset.Name, resolvedCommand, toolset.Args, env, cwd, lifecyclePolicyFromConfig(toolset.Name, toolset.Lifecycle)), nil - - // Remote MCP Server — working_dir is rejected at validation time for this - // branch (explicit remote.url in config). Ref-based MCPs that resolve to - // remote at runtime are handled with an explicit error in the Ref branch above. - case toolset.Remote.URL != "": - expander := js.NewJsExpander(envProvider) - - headers := expander.ExpandMap(ctx, toolset.Remote.Headers) - url := expander.Expand(ctx, toolset.Remote.URL, nil) - - return mcp.NewRemoteToolset(toolset.Name, url, toolset.Remote.TransportType, headers, toolset.Remote.OAuth, lifecyclePolicyFromConfig(toolset.Name, toolset.Lifecycle)), nil - - default: - return nil, errors.New("mcp toolset requires either ref, command, or remote configuration") - } -} - -func createA2ATool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - expander := js.NewJsExpander(runConfig.EnvProvider()) - - headers := expander.ExpandMap(ctx, toolset.Headers) - - return a2a.NewToolset(toolset.Name, toolset.URL, headers), nil -} - -func createLSPTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - // Auto-install missing command binary if needed - resolvedCommand, err := toolinstall.EnsureCommand(ctx, toolset.Command, toolset.Version) - if err != nil { - return nil, fmt.Errorf("resolving command %q: %w", toolset.Command, err) - } - - env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), runConfig.EnvProvider()) - if err != nil { - return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) - } - env = append(env, os.Environ()...) - - // Prepend tools bin dir to PATH so child processes can find installed tools - env = toolinstall.PrependBinDirToEnv(env) - - cwd := resolveToolsetWorkingDir(toolset.WorkingDir, runConfig.WorkingDir) - - // S1: validate the resolved directory exists (if one was specified) so we - // surface a clear error now rather than a cryptic exec failure later. - if toolset.WorkingDir != "" { - if err := checkDirExists(cwd, "lsp"); err != nil { - return nil, err - } - } - - tool := lsp.NewLSPTool(resolvedCommand, toolset.Args, env, cwd, lifecyclePolicyFromConfig(toolset.Name, toolset.Lifecycle)) - if len(toolset.FileTypes) > 0 { - tool.SetFileTypes(toolset.FileTypes) - } - - return tool, nil -} - -func createUserPromptTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - return userprompt.NewUserPromptTool(), nil -} - -func createOpenAPITool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - expander := js.NewJsExpander(runConfig.EnvProvider()) - - specURL := expander.Expand(ctx, toolset.URL, nil) - headers := expander.ExpandMap(ctx, toolset.Headers) - - return openapi.NewOpenAPITool(specURL, headers), nil -} - -func createModelPickerTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - if len(toolset.Models) == 0 { - return nil, errors.New("model_picker toolset requires at least one model") - } - return modelpicker.NewModelPickerTool(toolset.Models), nil -} - -func createBackgroundAgentsTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - return agenttool.NewToolSet(), nil -} - -func createRAGTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { - if toolset.RAGConfig == nil { - return nil, errors.New("rag toolset requires rag_config (should have been resolved from ref)") - } - - ragName := cmp.Or(toolset.Name, "rag") - - mgr, err := rag.NewManager(ctx, ragName, toolset.RAGConfig, rag.ManagersBuildConfig{ - ParentDir: parentDir, - ModelsGateway: runConfig.ModelsGateway, - Env: runConfig.EnvProvider(), - Models: runConfig.Models, - Providers: runConfig.Providers, - RuntimeConfig: runConfig, - }) - if err != nil { - return nil, fmt.Errorf("failed to create RAG manager: %w", err) - } - - toolName := cmp.Or(mgr.ToolName(), ragName) - return builtinrag.NewRAGTool(mgr, toolName), nil -} diff --git a/pkg/teamloader/registry_test.go b/pkg/teamloader/registry_test.go index 1db5050fb..a752135a0 100644 --- a/pkg/teamloader/registry_test.go +++ b/pkg/teamloader/registry_test.go @@ -73,90 +73,7 @@ func TestCreateMCPTool_BareCommandNotFound_CreatesToolsetAnyway(t *testing.T) { require.NoError(t, err) require.NotNil(t, tool) assert.Equal(t, "mcp(stdio cmd=some-nonexistent-mcp-binary)", tools.DescribeToolSet(tool)) -} - -func TestResolveToolsetWorkingDir(t *testing.T) { - t.Parallel() - - home, err := os.UserHomeDir() - require.NoError(t, err) - - tests := []struct { - name string - toolsetWorkingDir string - agentWorkingDir string - want string - }{ - { - name: "empty toolset dir returns agent dir", - toolsetWorkingDir: "", - agentWorkingDir: "/workspace", - want: "/workspace", - }, - { - name: "absolute toolset dir is returned as-is", - toolsetWorkingDir: "/tmp/mcp", - agentWorkingDir: "/workspace", - want: "/tmp/mcp", - }, - { - name: "relative toolset dir is joined with agent dir", - toolsetWorkingDir: "./backend", - agentWorkingDir: "/workspace", - want: "/workspace/backend", - }, - { - name: "bare relative dir joined with agent dir", - toolsetWorkingDir: "tools/mcp", - agentWorkingDir: "/workspace", - want: "/workspace/tools/mcp", - }, - { - name: "relative toolset dir with empty agent dir returns toolset dir unchanged", - toolsetWorkingDir: "./backend", - agentWorkingDir: "", - want: "./backend", - }, - { - name: "both empty returns empty", - toolsetWorkingDir: "", - agentWorkingDir: "", - want: "", - }, - // Tilde expansion tests (B2) - { - name: "tilde expands to home dir", - toolsetWorkingDir: "~/projects/app", - agentWorkingDir: "/workspace", - want: filepath.Join(home, "projects", "app"), - }, - { - name: "bare tilde expands to home dir", - toolsetWorkingDir: "~", - agentWorkingDir: "/workspace", - want: home, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := resolveToolsetWorkingDir(tt.toolsetWorkingDir, tt.agentWorkingDir) - assert.Equal(t, tt.want, got) - }) - } -} - -// TestResolveToolsetWorkingDir_EnvVarExpansion tests env-var expansion separately -// because t.Setenv is incompatible with t.Parallel on the parent test. -func TestResolveToolsetWorkingDir_EnvVarExpansion(t *testing.T) { - t.Setenv("TEST_REGISTRY_CWD_VAR", "/custom/path") - - got := resolveToolsetWorkingDir("${TEST_REGISTRY_CWD_VAR}/app", "/workspace") - assert.Equal(t, "/custom/path/app", got) -} - -// TestCreateMCPTool_WorkingDir_ReachesSubprocess verifies that working_dir is +} // TestCreateMCPTool_WorkingDir_ReachesSubprocess verifies that working_dir is // wired all the way through createMCPTool to the underlying stdio command (N5). func TestCreateMCPTool_WorkingDir_ReachesSubprocess(t *testing.T) { t.Setenv("DOCKER_AGENT_TOOLS_DIR", t.TempDir()) @@ -269,8 +186,8 @@ func TestCreateLSPTool_WorkingDir_ReachesHandler(t *testing.T) { require.NoError(t, err) require.NotNil(t, rawTool) - lspTool, ok := rawTool.(*lsp.Tool) - require.True(t, ok, "expected *lsp.Tool") + lspTool, ok := rawTool.(*lsp.ToolSet) + require.True(t, ok, "expected *lsp.ToolSet") assert.Equal(t, customDir, lspTool.WorkingDir()) } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 9d4d5029c..ba74d6329 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -38,7 +38,7 @@ var defaultMaxTokens int64 = 32000 type loadOptions struct { modelOverrides []string promptFiles []string - toolsetRegistry *ToolsetRegistry + toolsetRegistry ToolsetRegistry } type Opt func(*loadOptions) error @@ -60,7 +60,7 @@ func WithPromptFiles(files []string) Opt { } // WithToolsetRegistry allows using a custom toolset registry instead of the default -func WithToolsetRegistry(registry *ToolsetRegistry) Opt { +func WithToolsetRegistry(registry ToolsetRegistry) Opt { return func(opts *loadOptions) error { opts.toolsetRegistry = registry return nil @@ -221,7 +221,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c loadedSkills := skills.Load(agentConfig.Skills.Sources) loadedSkills = filterSkillsByName(loadedSkills, agentConfig.Skills.Include) if len(loadedSkills) > 0 { - agentTools = append(agentTools, skillstool.NewSkillsToolset(loadedSkills, runConfig.WorkingDir)) + agentTools = append(agentTools, skillstool.New(loadedSkills, runConfig.WorkingDir)) } } @@ -407,14 +407,14 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates // getToolsForAgent returns the tool definitions for an agent based on its // configuration. Toolset instructions support ${...} JavaScript placeholders // (e.g. ${env.X}); they are expanded here using the runtime env provider. -func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, runConfig *config.RuntimeConfig, registry *ToolsetRegistry, configName string, expander *js.Expander) ([]tools.ToolSet, []string) { +func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, runConfig *config.RuntimeConfig, registry ToolsetRegistry, configName string, expander *js.Expander) ([]tools.ToolSet, []string) { var ( toolSets []tools.ToolSet warnings []string lspBackends []lsp.Backend ) - deferredToolset := deferred.NewDeferredToolset() + deferredToolset := deferred.New() for i := range a.Toolsets { toolset := a.Toolsets[i] @@ -448,7 +448,7 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri // Instead of adding them individually (which causes duplicate tool names), // they are combined into a single Multiplexer after the loop. if toolset.Type == "lsp" { - if lspTool, ok := tool.(*lsp.Tool); ok { + if lspTool, ok := tool.(*lsp.ToolSet); ok { lspBackends = append(lspBackends, lsp.Backend{LSP: lspTool, Toolset: wrapped}) continue } @@ -472,10 +472,10 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri } if len(a.SubAgents) > 0 { - toolSets = append(toolSets, transfertask.NewTransferTaskTool()) + toolSets = append(toolSets, transfertask.New()) } if len(a.Handoffs) > 0 { - toolSets = append(toolSets, handoff.NewHandoffTool()) + toolSets = append(toolSets, handoff.New()) } // Wrap all tools in a single Code Mode toolset. diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index 60dca78f8..6381ca5a5 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -69,7 +69,7 @@ func TestGetToolsForAgent_ContinuesOnCreateToolError(t *testing.T) { expander := js.NewJsExpander(runConfig.EnvProvider()) - got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewToolsetRegistry(), "test-config", expander) + got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, &toolsetRegistry{}, "test-config", expander) require.Empty(t, got) require.NotEmpty(t, warnings) diff --git a/pkg/tools/a2a/a2a.go b/pkg/tools/a2a/a2a.go index 65c91b439..8f5f671c4 100644 --- a/pkg/tools/a2a/a2a.go +++ b/pkg/tools/a2a/a2a.go @@ -15,7 +15,10 @@ import ( "github.com/a2aproject/a2a-go/a2aclient" "github.com/a2aproject/a2a-go/a2aclient/agentcard" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/httpclient" + "github.com/docker/docker-agent/pkg/js" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/upstream" ) @@ -37,6 +40,13 @@ var ( _ tools.Instructable = (*Toolset)(nil) ) +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + expander := js.NewJsExpander(runConfig.EnvProvider()) + headers := expander.ExpandMap(ctx, toolset.Headers) + return NewToolset(toolset.Name, toolset.URL, headers), nil +} + // NewToolset creates a new A2A toolset for the given URL. func NewToolset(name, url string, headers map[string]string) *Toolset { return &Toolset{ diff --git a/pkg/tools/builtin/agent/agent.go b/pkg/tools/builtin/agent/agent.go index 01da68fbe..109be74b9 100644 --- a/pkg/tools/builtin/agent/agent.go +++ b/pkg/tools/builtin/agent/agent.go @@ -34,6 +34,11 @@ const ( maxOutputBytes = 10 * 1024 * 1024 // 10 MB ) +// CreateToolSet is used by the tools registry. +func CreateToolSet() (tools.ToolSet, error) { + return New(), nil +} + // RunBackgroundAgentArgs specifies the parameters for dispatching a sub-agent task asynchronously. type RunBackgroundAgentArgs struct { Agent string `json:"agent" jsonschema:"The name of the sub-agent to run in the background."` @@ -427,21 +432,21 @@ func (h *Handler) RegisterHandlers(register func(name string, fn func(context.Co register(ToolNameStopBackgroundAgent, h.HandleStop) } -// NewToolSet returns a lightweight ToolSet for registering background agent +// New returns a lightweight ToolSet for registering background agent // tool definitions and instructions. It does not require a Runner and is // suitable for use in the teamloader registry. -func NewToolSet() tools.ToolSet { - return &toolSet{} +func New() tools.ToolSet { + return &ToolSet{} } -// toolSet provides tool definitions and instructions without a Runner. -type toolSet struct{} +// ToolSet provides tool definitions and instructions without a Runner. +type ToolSet struct{} -func (t *toolSet) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return backgroundAgentTools(), nil } -func (t *toolSet) Instructions() string { +func (t *ToolSet) Instructions() string { return `# Background Agent Tasks Use background agent tasks to dispatch work to sub-agents concurrently. diff --git a/pkg/tools/builtin/agent/agent_test.go b/pkg/tools/builtin/agent/agent_test.go index 492bfd312..2b64d319d 100644 --- a/pkg/tools/builtin/agent/agent_test.go +++ b/pkg/tools/builtin/agent/agent_test.go @@ -592,7 +592,7 @@ func TestHandler_ConcurrentAccess(t *testing.T) { // --- Tools --- func TestNewToolSet_ReturnsFourTools(t *testing.T) { - ts := NewToolSet() + ts := New() toolsList, err := ts.Tools(t.Context()) require.NoError(t, err) assert.Len(t, toolsList, 4) @@ -608,7 +608,7 @@ func TestNewToolSet_ReturnsFourTools(t *testing.T) { } func TestNewToolSet_Instructions(t *testing.T) { - ts := NewToolSet() + ts := New() instructable, ok := ts.(tools.Instructable) require.True(t, ok, "NewToolSet should implement Instructable") diff --git a/pkg/tools/builtin/api/api.go b/pkg/tools/builtin/api/api.go index 56c3239a7..6032f084c 100644 --- a/pkg/tools/builtin/api/api.go +++ b/pkg/tools/builtin/api/api.go @@ -12,6 +12,7 @@ import ( "net/url" "time" + "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/js" @@ -20,7 +21,7 @@ import ( "github.com/docker/docker-agent/pkg/useragent" ) -type Tool struct { +type ToolSet struct { config latest.APIToolConfig expander *js.Expander @@ -31,11 +32,11 @@ type Tool struct { // Verify interface compliance var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) -func (t *Tool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { +func (t *ToolSet) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { client := httpclient.NewSafeClient(30*time.Second, t.unsafe) endpoint := t.config.Endpoint @@ -89,18 +90,31 @@ func (t *Tool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools.To return tools.ResultSuccess(limitOutput(string(body))), nil } -func NewAPITool(config latest.APIToolConfig, expander *js.Expander) *Tool { - return &Tool{ - config: config, +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + if toolset.APIConfig.Endpoint == "" { + return nil, errors.New("api tool requires an endpoint in api_config") + } + + expander := js.NewJsExpander(runConfig.EnvProvider()) + toolset.APIConfig.Endpoint = expander.Expand(ctx, toolset.APIConfig.Endpoint, nil) + toolset.APIConfig.Headers = expander.ExpandMap(ctx, toolset.APIConfig.Headers) + + return New(toolset.APIConfig, expander), nil +} + +func New(apiConfig latest.APIToolConfig, expander *js.Expander) *ToolSet { + return &ToolSet{ + config: apiConfig, expander: expander, } } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return t.config.Instruction } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { inputSchema, err := tools.SchemaToMap(map[string]any{ "type": "object", "properties": t.config.Args, diff --git a/pkg/tools/builtin/api/api_test.go b/pkg/tools/builtin/api/api_test.go index bfe51a02b..6fb4bb786 100644 --- a/pkg/tools/builtin/api/api_test.go +++ b/pkg/tools/builtin/api/api_test.go @@ -21,9 +21,9 @@ import ( // newAPIToolForTest constructs an APITool that bypasses SSRF dial-time // protection so tests can talk to httptest.NewServer (which binds to // 127.0.0.1). It is defined in a *_test.go file so it is not compiled -// into release binaries. Production callers must use [NewAPITool]. -func newAPIToolForTest(config latest.APIToolConfig, expander *js.Expander) *Tool { - t := NewAPITool(config, expander) +// into release binaries. Production callers must use [New]. +func newAPIToolForTest(config latest.APIToolConfig, expander *js.Expander) *ToolSet { + t := New(config, expander) t.unsafe = true return t } @@ -158,7 +158,7 @@ func TestAPITool_IdentityHeaders(t *testing.T) { func TestAPITool_DefaultOutputSchema(t *testing.T) { t.Parallel() - tool := NewAPITool(latest.APIToolConfig{ + tool := New(latest.APIToolConfig{ Name: "default-schema", Method: http.MethodGet, Endpoint: "https://example.com/api", @@ -184,7 +184,7 @@ func TestAPITool_CustomOutputSchema(t *testing.T) { "required": []string{"first_name", "last_name"}, } - tool := NewAPITool(latest.APIToolConfig{ + tool := New(latest.APIToolConfig{ Name: "user-info", Method: http.MethodGet, Endpoint: "https://example.com/api/users/${id}", @@ -223,7 +223,7 @@ func TestAPITool_RejectsLocalAddresses(t *testing.T) { for _, target := range tests { t.Run(target, func(t *testing.T) { t.Parallel() - tool := NewAPITool(latest.APIToolConfig{ + tool := New(latest.APIToolConfig{ Method: http.MethodGet, Endpoint: target, }, testExpander()) diff --git a/pkg/tools/builtin/deferred/deferred.go b/pkg/tools/builtin/deferred/deferred.go index e18b354fb..2cfd4ed97 100644 --- a/pkg/tools/builtin/deferred/deferred.go +++ b/pkg/tools/builtin/deferred/deferred.go @@ -21,7 +21,7 @@ type deferredToolEntry struct { source tools.ToolSet } -type Toolset struct { +type ToolSet struct { mu sync.RWMutex deferredTools map[string]deferredToolEntry activatedTools map[string]tools.Tool @@ -30,9 +30,9 @@ type Toolset struct { // Verify interface compliance var ( - _ tools.ToolSet = (*Toolset)(nil) - _ tools.Startable = (*Toolset)(nil) - _ tools.Instructable = (*Toolset)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Startable = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) type deferredSource struct { @@ -41,14 +41,14 @@ type deferredSource struct { tools []string } -func NewDeferredToolset() *Toolset { - return &Toolset{ +func New() *ToolSet { + return &ToolSet{ deferredTools: make(map[string]deferredToolEntry), activatedTools: make(map[string]tools.Tool), } } -func (d *Toolset) AddSource(toolset tools.ToolSet, deferAll bool, toolNames []string) { +func (d *ToolSet) AddSource(toolset tools.ToolSet, deferAll bool, toolNames []string) { d.mu.Lock() defer d.mu.Unlock() @@ -59,13 +59,13 @@ func (d *Toolset) AddSource(toolset tools.ToolSet, deferAll bool, toolNames []st }) } -func (d *Toolset) HasSources() bool { +func (d *ToolSet) HasSources() bool { d.mu.RLock() defer d.mu.RUnlock() return len(d.sources) > 0 } -func (d *Toolset) Instructions() string { +func (d *ToolSet) Instructions() string { return `## Deferred Tools Use search_tool to discover additional tools by keyword (e.g., "remote", "read", "write"). Use add_tool to activate a discovered tool.` @@ -84,7 +84,7 @@ type AddToolArgs struct { Name string `json:"name" jsonschema:"The name of the tool to activate"` } -func (d *Toolset) handleSearchTool(_ context.Context, args SearchToolArgs) (*tools.ToolCallResult, error) { +func (d *ToolSet) handleSearchTool(_ context.Context, args SearchToolArgs) (*tools.ToolCallResult, error) { query := strings.ToLower(args.Query) d.mu.RLock() @@ -115,7 +115,7 @@ func (d *Toolset) handleSearchTool(_ context.Context, args SearchToolArgs) (*too return tools.ResultSuccess(fmt.Sprintf("Found %d deferred tool(s):\n%s", len(results), string(output))), nil } -func (d *Toolset) handleAddTool(_ context.Context, args AddToolArgs) (*tools.ToolCallResult, error) { +func (d *ToolSet) handleAddTool(_ context.Context, args AddToolArgs) (*tools.ToolCallResult, error) { d.mu.Lock() defer d.mu.Unlock() @@ -134,7 +134,7 @@ func (d *Toolset) handleAddTool(_ context.Context, args AddToolArgs) (*tools.Too return tools.ResultSuccess(fmt.Sprintf("Tool '%s' has been activated and is now available for use.\n\nDescription: %s", args.Name, entry.tool.Description)), nil } -func (d *Toolset) Tools(context.Context) ([]tools.Tool, error) { +func (d *ToolSet) Tools(context.Context) ([]tools.Tool, error) { d.mu.RLock() defer d.mu.RUnlock() @@ -172,7 +172,7 @@ func (d *Toolset) Tools(context.Context) ([]tools.Tool, error) { return result, nil } -func (d *Toolset) Start(ctx context.Context) error { +func (d *ToolSet) Start(ctx context.Context) error { // Note: we are not responsible for starting the underlying toolsets here d.mu.RLock() defer d.mu.RUnlock() @@ -200,6 +200,6 @@ func (d *Toolset) Start(ctx context.Context) error { return nil } -func (d *Toolset) Stop(context.Context) error { +func (d *ToolSet) Stop(context.Context) error { return nil } diff --git a/pkg/tools/builtin/deferred/deferred_test.go b/pkg/tools/builtin/deferred/deferred_test.go index b6224a1e0..d053e0c03 100644 --- a/pkg/tools/builtin/deferred/deferred_test.go +++ b/pkg/tools/builtin/deferred/deferred_test.go @@ -28,7 +28,7 @@ func TestDeferredToolset_SearchTool(t *testing.T) { }, } - dt := NewDeferredToolset() + dt := New() dt.AddSource(mockTools, true, nil) err := dt.Start(ctx) require.NoError(t, err) @@ -65,7 +65,7 @@ func TestDeferredToolset_AddTool(t *testing.T) { }, } - dt := NewDeferredToolset() + dt := New() dt.AddSource(mockTools, true, nil) err := dt.Start(ctx) require.NoError(t, err) diff --git a/pkg/tools/builtin/fetch/fetch.go b/pkg/tools/builtin/fetch/fetch.go index 3cbd0b379..ab48fb747 100644 --- a/pkg/tools/builtin/fetch/fetch.go +++ b/pkg/tools/builtin/fetch/fetch.go @@ -16,7 +16,10 @@ import ( "github.com/k3a/html2text" "github.com/temoto/robotstxt" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/httpclient" + "github.com/docker/docker-agent/pkg/js" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/useragent" ) @@ -25,14 +28,14 @@ const ( ToolNameFetch = "fetch" ) -type Tool struct { +type ToolSet struct { handler *fetchHandler } // Verify interface compliance var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) type fetchHandler struct { @@ -452,8 +455,30 @@ func htmlToText(html string) string { return html2text.HTML2Text(html) } -func NewFetchTool(options ...ToolOption) *Tool { - tool := &Tool{ +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + expander := js.NewJsExpander(runConfig.EnvProvider()) + + var opts []ToolOption + if toolset.Timeout > 0 { + timeout := time.Duration(toolset.Timeout) * time.Second + opts = append(opts, WithTimeout(timeout)) + } + if len(toolset.AllowedDomains) > 0 { + opts = append(opts, WithAllowedDomains(toolset.AllowedDomains)) + } + if len(toolset.BlockedDomains) > 0 { + opts = append(opts, WithBlockedDomains(toolset.BlockedDomains)) + } + if toolset.AllowPrivateIPs { + opts = append(opts, WithAllowPrivateIPs(true)) + } + opts = append(opts, WithHeaders(expander.ExpandMap(ctx, toolset.Headers))) + return New(opts...), nil +} + +func New(options ...ToolOption) *ToolSet { + tool := &ToolSet{ handler: &fetchHandler{ timeout: 30 * time.Second, }, @@ -466,10 +491,10 @@ func NewFetchTool(options ...ToolOption) *Tool { return tool } -type ToolOption func(*Tool) +type ToolOption func(*ToolSet) func WithTimeout(timeout time.Duration) ToolOption { - return func(t *Tool) { + return func(t *ToolSet) { t.handler.timeout = timeout } } @@ -478,7 +503,7 @@ func WithTimeout(timeout time.Duration) ToolOption { // of the supplied domain patterns. See matchesDomain for matching rules. // An empty or nil slice disables the allow-list (every host is allowed). func WithAllowedDomains(domains []string) ToolOption { - return func(t *Tool) { + return func(t *ToolSet) { t.handler.allowedDomains = domains } } @@ -487,7 +512,7 @@ func WithAllowedDomains(domains []string) ToolOption { // matches one of the supplied domain patterns. See matchesDomain for matching // rules. An empty or nil slice disables the deny-list. func WithBlockedDomains(domains []string) ToolOption { - return func(t *Tool) { + return func(t *ToolSet) { t.handler.blockedDomains = domains } } @@ -499,7 +524,7 @@ func WithBlockedDomains(domains []string) ToolOption { // so DNS rebinding cannot bypass the check. Set to true only when an // agent legitimately needs to reach internal services. func WithAllowPrivateIPs(allow bool) ToolOption { - return func(t *Tool) { + return func(t *ToolSet) { t.handler.allowPrivateIPs = allow } } @@ -509,12 +534,12 @@ func WithAllowPrivateIPs(allow bool) ToolOption { // These are applied last, so they override the default User-Agent and the // format-driven Accept header. An empty or nil map is a no-op. func WithHeaders(headers map[string]string) ToolOption { - return func(t *Tool) { + return func(t *ToolSet) { t.handler.headers = headers } } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { var b strings.Builder b.WriteString("## Fetch Tool\n\nFetch content from HTTP/HTTPS URLs. Supports multiple URLs per call, output format selection (text, markdown, html), and respects robots.txt.") if d := t.handler.allowedDomains; len(d) > 0 { @@ -526,7 +551,7 @@ func (t *Tool) Instructions() string { return b.String() } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameFetch, diff --git a/pkg/tools/builtin/fetch/fetch_test.go b/pkg/tools/builtin/fetch/fetch_test.go index 0909723fa..7627b6148 100644 --- a/pkg/tools/builtin/fetch/fetch_test.go +++ b/pkg/tools/builtin/fetch/fetch_test.go @@ -20,14 +20,14 @@ import ( // newFetchToolForTest constructs a FetchTool that bypasses SSRF dial-time // protection so tests can talk to httptest.NewServer (which binds to // 127.0.0.1). It is defined in a *_test.go file so it is not compiled -// into release binaries. Production callers must use [NewFetchTool], +// into release binaries. Production callers must use [New], // which refuses non-public addresses by default. // // The helper prepends [WithAllowPrivateIPs](true) to opts so explicit // caller options still take precedence (a later option overrides an // earlier one). -func newFetchToolForTest(opts ...ToolOption) *Tool { - return NewFetchTool(append([]ToolOption{WithAllowPrivateIPs(true)}, opts...)...) +func newFetchToolForTest(opts ...ToolOption) *ToolSet { + return New(append([]ToolOption{WithAllowPrivateIPs(true)}, opts...)...) } func TestFetchToolWithOptions(t *testing.T) { @@ -860,7 +860,7 @@ func TestFetch_Headers_StrippedOnRobotsCrossHostRedirect(t *testing.T) { } // TestFetch_DefaultRefusesNonPublicAddresses pins the security-relevant -// default: an out-of-the-box [NewFetchTool] (no WithAllowPrivateIPs) +// default: an out-of-the-box [New] (no WithAllowPrivateIPs) // must refuse to dial loopback, RFC1918, link-local incl. cloud // metadata, multicast and the unspecified address. These are checked // after DNS resolution, so a public hostname that resolves to a private @@ -878,7 +878,7 @@ func TestFetch_DefaultRefusesNonPublicAddresses(t *testing.T) { for _, target := range tests { t.Run(target, func(t *testing.T) { t.Parallel() - tool := NewFetchTool() // production constructor, no opt-in + tool := New() // production constructor, no opt-in result, err := tool.handler.CallTool(t.Context(), ToolArgs{ URLs: []string{target}, Format: "text", @@ -902,7 +902,7 @@ func TestFetch_AllowPrivateIPsRestoresLegacyBehaviour(t *testing.T) { })) t.Cleanup(server.Close) - tool := NewFetchTool(WithAllowPrivateIPs(true)) + tool := New(WithAllowPrivateIPs(true)) result, err := tool.handler.CallTool(t.Context(), ToolArgs{ URLs: []string{server.URL}, Format: "text", diff --git a/pkg/tools/builtin/filesystem/filesystem.go b/pkg/tools/builtin/filesystem/filesystem.go index 2f9608a23..95293737d 100644 --- a/pkg/tools/builtin/filesystem/filesystem.go +++ b/pkg/tools/builtin/filesystem/filesystem.go @@ -16,6 +16,8 @@ import ( "sync" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/fsx" pathx "github.com/docker/docker-agent/pkg/path" "github.com/docker/docker-agent/pkg/tools" @@ -39,7 +41,7 @@ type PostEditConfig struct { Cmd string // Command to execute (with $path placeholder) } -type Tool struct { +type ToolSet struct { workingDir string postEditCommands []PostEditConfig ignoreVCS bool @@ -62,21 +64,21 @@ type Tool struct { // Verify interface compliance var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) - _ io.Closer = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) + _ io.Closer = (*ToolSet)(nil) ) -type Opt func(*Tool) +type Opt func(*ToolSet) func WithPostEditCommands(postEditCommands []PostEditConfig) Opt { - return func(t *Tool) { + return func(t *ToolSet) { t.postEditCommands = postEditCommands } } func WithIgnoreVCS(ignoreVCS bool) Opt { - return func(t *Tool) { + return func(t *ToolSet) { t.ignoreVCS = ignoreVCS } } @@ -96,7 +98,7 @@ func WithIgnoreVCS(ignoreVCS bool) Opt { // Invalid entries (e.g. an empty string) are logged and the allow-list is // silently dropped, mirroring how WithIgnoreVCS handles construction errors. func WithAllowList(roots []string) Opt { - return func(t *Tool) { + return func(t *ToolSet) { set, err := newPathRootSet(t.workingDir, roots) if err != nil { slog.Error("filesystem allow-list: invalid entry; disabling toolset", "error", err) @@ -113,7 +115,7 @@ func WithAllowList(roots []string) Opt { // path that matches both is rejected. An empty or nil slice disables the // deny-list. func WithDenyList(roots []string) Opt { - return func(t *Tool) { + return func(t *ToolSet) { set, err := newPathRootSet(t.workingDir, roots) if err != nil { slog.Error("filesystem deny-list: invalid entry; disabling toolset", "error", err) @@ -124,8 +126,48 @@ func WithDenyList(roots []string) Opt { } } -func NewFilesystemTool(workingDir string, opts ...Opt) *Tool { - t := &Tool{ +// CreateToolSet is used by the tools registry. +func CreateToolSet(toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + wd := runConfig.WorkingDir + if wd == "" { + var err error + wd, err = os.Getwd() + if err != nil { + return nil, fmt.Errorf("failed to get working directory: %w", err) + } + } + + var opts []Opt + + ignoreVCS := true + if toolset.IgnoreVCS != nil { + ignoreVCS = *toolset.IgnoreVCS + } + opts = append(opts, WithIgnoreVCS(ignoreVCS)) + + if len(toolset.AllowList) > 0 { + opts = append(opts, WithAllowList(toolset.AllowList)) + } + if len(toolset.DenyList) > 0 { + opts = append(opts, WithDenyList(toolset.DenyList)) + } + + if len(toolset.PostEdit) > 0 { + postEditConfigs := make([]PostEditConfig, len(toolset.PostEdit)) + for i, pe := range toolset.PostEdit { + postEditConfigs[i] = PostEditConfig{ + Path: pe.Path, + Cmd: pe.Cmd, + } + } + opts = append(opts, WithPostEditCommands(postEditConfigs)) + } + + return New(wd, opts...), nil +} + +func New(workingDir string, opts ...Opt) *ToolSet { + t := &ToolSet{ workingDir: workingDir, } @@ -138,7 +180,7 @@ func NewFilesystemTool(workingDir string, opts ...Opt) *Tool { // Close releases any *os.Root file descriptors held by the allow/deny lists. // It is safe to call Close multiple times. -func (t *Tool) Close() error { +func (t *ToolSet) Close() error { if t.allowList != nil { t.allowList.close() } @@ -148,7 +190,7 @@ func (t *Tool) Close() error { return nil } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { var b strings.Builder b.WriteString(`## Filesystem Tools @@ -345,7 +387,7 @@ func tryRepairEditFileJSON(data []byte) ([]byte, bool) { return nil, false } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameDirectoryTree, @@ -483,7 +525,7 @@ func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { } // executePostEditCommands executes any matching post-edit commands for the given file path -func (t *Tool) executePostEditCommands(ctx context.Context, filePath string) error { +func (t *ToolSet) executePostEditCommands(ctx context.Context, filePath string) error { if len(t.postEditCommands) == 0 { return nil } @@ -500,7 +542,7 @@ func (t *Tool) executePostEditCommands(ctx context.Context, filePath string) err // resolvePath does NOT enforce the allow- or deny-lists; callers should use // [resolveAndCheckPath] when those checks are required (i.e. for any path // that originates from a tool argument). -func (t *Tool) resolvePath(path string) string { +func (t *ToolSet) resolvePath(path string) string { if expandedPath, err := pathx.ExpandHomeDir(path); err == nil { path = expandedPath } @@ -550,7 +592,7 @@ func (t *Tool) resolvePath(path string) string { // the allow-list root. // - On GOOS=windows, [*os.Root] additionally rejects reserved device // names (NUL, COM1, …), which is a strengthening, not a weakening. -func (t *Tool) resolveAndCheckPath(path string) (string, error) { +func (t *ToolSet) resolveAndCheckPath(path string) (string, error) { if t.sandboxBroken { return "", errors.New("filesystem toolset is disabled due to invalid allow/deny list configuration") } @@ -586,7 +628,7 @@ func (t *Tool) resolveAndCheckPath(path string) (string, error) { // - Path no longer inside any entry (e.g. a symlink swap moved the real // target out between the static check and the I/O) → (nil, "", err): // callers MUST refuse; falling back to os.* would follow the symlink. -func (t *Tool) rootedAccess(resolved string) (*os.Root, string, error) { +func (t *ToolSet) rootedAccess(resolved string) (*os.Root, string, error) { if t.allowList == nil { return nil, "", nil } @@ -605,7 +647,7 @@ func (t *Tool) rootedAccess(resolved string) (*os.Root, string, error) { // allow-list contains. When no rooted access is available it falls back to // the plain [os.ReadFile]. Callers MUST pass a path that has already been // validated by [resolveAndCheckPath]. -func (t *Tool) readFile(resolved string) ([]byte, error) { +func (t *ToolSet) readFile(resolved string) ([]byte, error) { root, rel, err := t.rootedAccess(resolved) if err != nil { return nil, err @@ -620,7 +662,7 @@ func (t *Tool) readFile(resolved string) ([]byte, error) { // for the contract. The call is rejected by the kernel when any component // of rel is an out-of-root symlink, so an attacker cannot win the swap // race between the [resolveAndCheckPath] check and the write. -func (t *Tool) writeFile(resolved string, data []byte, perm os.FileMode) error { +func (t *ToolSet) writeFile(resolved string, data []byte, perm os.FileMode) error { root, rel, err := t.rootedAccess(resolved) if err != nil { return err @@ -633,7 +675,7 @@ func (t *Tool) writeFile(resolved string, data []byte, perm os.FileMode) error { // stat is a TOCTOU-safe equivalent of [os.Stat]. See [readFile] for the // contract. -func (t *Tool) stat(resolved string) (os.FileInfo, error) { +func (t *ToolSet) stat(resolved string) (os.FileInfo, error) { root, rel, err := t.rootedAccess(resolved) if err != nil { return nil, err @@ -647,7 +689,7 @@ func (t *Tool) stat(resolved string) (os.FileInfo, error) { // mkdirAll is a TOCTOU-safe equivalent of [os.MkdirAll]. See [readFile] // for the contract. A rooted MkdirAll on "." is a no-op (the root already // exists by construction). -func (t *Tool) mkdirAll(resolved string, perm os.FileMode) error { +func (t *ToolSet) mkdirAll(resolved string, perm os.FileMode) error { root, rel, err := t.rootedAccess(resolved) if err != nil { return err @@ -664,7 +706,7 @@ func (t *Tool) mkdirAll(resolved string, perm os.FileMode) error { // readDir is a TOCTOU-safe equivalent of [os.ReadDir]. See [readFile] // for the contract. We use [*os.Root].Open + [*os.File].ReadDir because // [*os.Root] does not expose ReadDir directly. -func (t *Tool) readDir(resolved string) ([]os.DirEntry, error) { +func (t *ToolSet) readDir(resolved string) ([]os.DirEntry, error) { root, rel, err := t.rootedAccess(resolved) if err != nil { return nil, err @@ -684,7 +726,7 @@ func (t *Tool) readDir(resolved string) ([]os.DirEntry, error) { // is available we use [*os.Root].Remove, which only unlinks the named // directory entry and refuses to follow a trailing symlink that escapes // the root. Otherwise we fall back to the platform-specific [rmdir]. -func (t *Tool) removeDir(resolved string) error { +func (t *ToolSet) removeDir(resolved string) error { root, rel, err := t.rootedAccess(resolved) if err != nil { return err @@ -697,7 +739,7 @@ func (t *Tool) removeDir(resolved string) error { // initGitignoreMatcher initializes the gitignore matcher for the working directory. // It is safe to call multiple times; initialization only happens once. -func (t *Tool) initGitignoreMatcher() { +func (t *ToolSet) initGitignoreMatcher() { if !t.ignoreVCS { return } @@ -720,7 +762,7 @@ func (t *Tool) initGitignoreMatcher() { } // shouldIgnorePath checks if a path should be ignored based on VCS rules -func (t *Tool) shouldIgnorePath(path string) bool { +func (t *ToolSet) shouldIgnorePath(path string) bool { if !t.ignoreVCS { return false } @@ -739,7 +781,7 @@ func (t *Tool) shouldIgnorePath(path string) bool { // Handler implementations -func (t *Tool) handleDirectoryTree(ctx context.Context, args DirectoryTreeArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleDirectoryTree(ctx context.Context, args DirectoryTreeArgs) (*tools.ToolCallResult, error) { resolvedPath, err := t.resolveAndCheckPath(args.Path) if err != nil { return tools.ResultError(err.Error()), nil @@ -797,7 +839,7 @@ func countTreeNodes(node *fsx.TreeNode) (files, dirs int) { // repair logic for malformed JSON, then delegates to handleEditFile. // This bypasses tools.NewHandler because Go's json.Unmarshal scanner rejects // structurally invalid JSON before calling any custom UnmarshalJSON method. -func (t *Tool) editFileHandler() tools.ToolHandler { +func (t *ToolSet) editFileHandler() tools.ToolHandler { return func(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { data := toolCall.Function.Arguments if data == "" { @@ -811,7 +853,7 @@ func (t *Tool) editFileHandler() tools.ToolHandler { } } -func (t *Tool) handleEditFile(ctx context.Context, args EditFileArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleEditFile(ctx context.Context, args EditFileArgs) (*tools.ToolCallResult, error) { resolvedPath, err := t.resolveAndCheckPath(args.Path) if err != nil { return tools.ResultError(err.Error()), nil @@ -849,7 +891,7 @@ func (t *Tool) handleEditFile(ctx context.Context, args EditFileArgs) (*tools.To return tools.ResultSuccess("File edited successfully. Changes:\n" + strings.Join(changes, "\n")), nil } -func (t *Tool) handleListDirectory(_ context.Context, args ListDirectoryArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleListDirectory(_ context.Context, args ListDirectoryArgs) (*tools.ToolCallResult, error) { resolvedPath, err := t.resolveAndCheckPath(args.Path) if err != nil { return tools.ResultError(err.Error()), nil @@ -890,7 +932,7 @@ func (t *Tool) handleListDirectory(_ context.Context, args ListDirectoryArgs) (* }, nil } -func (t *Tool) handleReadFile(_ context.Context, args ReadFileArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleReadFile(_ context.Context, args ReadFileArgs) (*tools.ToolCallResult, error) { resolvedPath, err := t.resolveAndCheckPath(args.Path) if err != nil { return &tools.ToolCallResult{ @@ -947,7 +989,7 @@ func (t *Tool) handleReadFile(_ context.Context, args ReadFileArgs) (*tools.Tool // readImageFile reads an image file and returns it as base64-encoded image content. // The caller must ensure the file exists (e.g. via os.Stat) before calling this method. -func (t *Tool) readImageFile(resolvedPath, originalPath string) (*tools.ToolCallResult, error) { +func (t *ToolSet) readImageFile(resolvedPath, originalPath string) (*tools.ToolCallResult, error) { data, err := t.readFile(resolvedPath) if err != nil { errMsg := err.Error() @@ -996,7 +1038,7 @@ func (t *Tool) readImageFile(resolvedPath, originalPath string) (*tools.ToolCall }, nil } -func (t *Tool) handleReadMultipleFiles(ctx context.Context, args ReadMultipleFilesArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleReadMultipleFiles(ctx context.Context, args ReadMultipleFilesArgs) (*tools.ToolCallResult, error) { type PathContent struct { Path string `json:"path"` Content string `json:"content"` @@ -1070,7 +1112,7 @@ func (t *Tool) handleReadMultipleFiles(ctx context.Context, args ReadMultipleFil }, nil } -func (t *Tool) handleSearchFilesContent(_ context.Context, args SearchFilesContentArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleSearchFilesContent(_ context.Context, args SearchFilesContentArgs) (*tools.ToolCallResult, error) { resolvedPath, err := t.resolveAndCheckPath(args.Path) if err != nil { return tools.ResultError(err.Error()), nil @@ -1189,7 +1231,7 @@ func (t *Tool) handleSearchFilesContent(_ context.Context, args SearchFilesConte }, nil } -func (t *Tool) handleWriteFile(ctx context.Context, args WriteFileArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleWriteFile(ctx context.Context, args WriteFileArgs) (*tools.ToolCallResult, error) { resolvedPath, err := t.resolveAndCheckPath(args.Path) if err != nil { return tools.ResultError(err.Error()), nil @@ -1212,7 +1254,7 @@ func (t *Tool) handleWriteFile(ctx context.Context, args WriteFileArgs) (*tools. return tools.ResultSuccess(fmt.Sprintf("File written successfully: %s (%d bytes)", args.Path, len(args.Content))), nil } -func (t *Tool) handleCreateDirectory(_ context.Context, args CreateDirectoryArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleCreateDirectory(_ context.Context, args CreateDirectoryArgs) (*tools.ToolCallResult, error) { var results []string for _, path := range args.Paths { resolvedPath, err := t.resolveAndCheckPath(path) @@ -1228,7 +1270,7 @@ func (t *Tool) handleCreateDirectory(_ context.Context, args CreateDirectoryArgs return tools.ResultSuccess(strings.Join(results, "\n")), nil } -func (t *Tool) handleRemoveDirectory(_ context.Context, args RemoveDirectoryArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleRemoveDirectory(_ context.Context, args RemoveDirectoryArgs) (*tools.ToolCallResult, error) { var results []string for _, path := range args.Paths { resolvedPath, err := t.resolveAndCheckPath(path) diff --git a/pkg/tools/builtin/filesystem/filesystem_paths_test.go b/pkg/tools/builtin/filesystem/filesystem_paths_test.go index d5ce847b7..164ac0aa5 100644 --- a/pkg/tools/builtin/filesystem/filesystem_paths_test.go +++ b/pkg/tools/builtin/filesystem/filesystem_paths_test.go @@ -21,7 +21,7 @@ func resetHomeDir(t *testing.T, dir string) { func TestFilesystemTool_DefaultIsUnrestricted(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) // No allow_list, no deny_list: everything resolvable goes through. resolved, err := tool.resolveAndCheckPath("/etc/hosts") @@ -38,7 +38,7 @@ func TestFilesystemTool_DefaultIsUnrestricted(t *testing.T) { func TestFilesystemTool_AllowList_DotMeansWorkingDir(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir, WithAllowList([]string{"."})) + tool := New(tmpDir, WithAllowList([]string{"."})) // Inside working dir is fine. _, err := tool.resolveAndCheckPath("file.txt") @@ -62,7 +62,7 @@ func TestFilesystemTool_AllowList_TildeMeansHome(t *testing.T) { resetHomeDir(t, homeDir) wd := t.TempDir() - tool := NewFilesystemTool(wd, WithAllowList([]string{"~"})) + tool := New(wd, WithAllowList([]string{"~"})) // A path under $HOME is allowed via ~/... resolved, err := tool.resolveAndCheckPath(filepath.Join(homeDir, "doc.md")) @@ -81,7 +81,7 @@ func TestFilesystemTool_AllowList_TildeSubdirectory(t *testing.T) { require.NoError(t, os.MkdirAll(filepath.Join(homeDir, "projects"), 0o755)) wd := t.TempDir() - tool := NewFilesystemTool(wd, WithAllowList([]string{"~/projects"})) + tool := New(wd, WithAllowList([]string{"~/projects"})) // Inside the listed subdir. _, err := tool.resolveAndCheckPath(filepath.Join(homeDir, "projects", "app", "main.go")) @@ -101,7 +101,7 @@ func TestFilesystemTool_AllowList_MultipleRoots(t *testing.T) { wd := t.TempDir() otherDir := t.TempDir() - tool := NewFilesystemTool(wd, WithAllowList([]string{".", otherDir})) + tool := New(wd, WithAllowList([]string{".", otherDir})) _, err := tool.resolveAndCheckPath("file.txt") require.NoError(t, err) @@ -118,7 +118,7 @@ func TestFilesystemTool_AllowList_AbsolutePath(t *testing.T) { wd := t.TempDir() allowed := t.TempDir() - tool := NewFilesystemTool(wd, WithAllowList([]string{allowed})) + tool := New(wd, WithAllowList([]string{allowed})) // Absolute path inside the allowed root is fine. _, err := tool.resolveAndCheckPath(filepath.Join(allowed, "x", "y.txt")) @@ -135,7 +135,7 @@ func TestFilesystemTool_DenyList_RejectsMatchingPaths(t *testing.T) { denied := filepath.Join(wd, "secret") require.NoError(t, os.Mkdir(denied, 0o755)) - tool := NewFilesystemTool(wd, WithDenyList([]string{"secret"})) + tool := New(wd, WithDenyList([]string{"secret"})) // Anything under the denied subtree is rejected. _, err := tool.resolveAndCheckPath("secret/key.pem") @@ -158,7 +158,7 @@ func TestFilesystemTool_DenyList_TakesPrecedenceOverAllowList(t *testing.T) { require.NoError(t, os.MkdirAll(filepath.Join(wd, "src"), 0o755)) require.NoError(t, os.MkdirAll(filepath.Join(wd, "src", "vendor"), 0o755)) - tool := NewFilesystemTool(wd, + tool := New(wd, WithAllowList([]string{"."}), WithDenyList([]string{"src/vendor"})) @@ -182,7 +182,7 @@ func TestFilesystemTool_AllowList_SymlinkEscapeRejected(t *testing.T) { link := filepath.Join(wd, "escape") require.NoError(t, os.Symlink(target, link)) - tool := NewFilesystemTool(wd, WithAllowList([]string{"."})) + tool := New(wd, WithAllowList([]string{"."})) // Following the symlink escapes the allow-list and must be rejected. _, err := tool.resolveAndCheckPath("escape/secret.txt") @@ -201,7 +201,7 @@ func TestFilesystemTool_DenyList_SymlinkIntoDeniedAreaRejected(t *testing.T) { link := filepath.Join(wd, "shortcut") require.NoError(t, os.Symlink(denied, link)) - tool := NewFilesystemTool(wd, WithDenyList([]string{"secret"})) + tool := New(wd, WithDenyList([]string{"secret"})) // Reading via the symlink must still trigger the deny-list. _, err := tool.resolveAndCheckPath("shortcut/key.pem") @@ -212,7 +212,7 @@ func TestFilesystemTool_DenyList_SymlinkIntoDeniedAreaRejected(t *testing.T) { func TestFilesystemTool_AllowList_NewFilePath(t *testing.T) { t.Parallel() wd := t.TempDir() - tool := NewFilesystemTool(wd, WithAllowList([]string{"."})) + tool := New(wd, WithAllowList([]string{"."})) // A path that doesn't exist yet (e.g. about to be created by write_file) // must still be accepted when its lexical location is inside the allow-list. @@ -231,7 +231,7 @@ func TestFilesystemTool_AllowList_EmptyDisablesCheck(t *testing.T) { // nil and empty slice both leave the allow-list disabled. for _, roots := range [][]string{nil, {}} { - tool := NewFilesystemTool(tmpDir, WithAllowList(roots)) + tool := New(tmpDir, WithAllowList(roots)) _, err := tool.resolveAndCheckPath("/etc/hosts") require.NoError(t, err, "empty/nil allow-list must not constrain") } @@ -246,7 +246,7 @@ func TestFilesystemTool_HandlersUseAllowList(t *testing.T) { outsideFile := filepath.Join(other, "outside.txt") require.NoError(t, os.WriteFile(outsideFile, []byte("nope"), 0o644)) - tool := NewFilesystemTool(wd, WithAllowList([]string{"."})) + tool := New(wd, WithAllowList([]string{"."})) // read_file: must refuse the outside path. res, err := tool.handleReadFile(t.Context(), ReadFileArgs{Path: outsideFile}) @@ -315,7 +315,7 @@ func TestFilesystemTool_HandlersUseDenyList(t *testing.T) { require.NoError(t, os.Mkdir(filepath.Join(wd, "secrets"), 0o755)) require.NoError(t, os.WriteFile(filepath.Join(wd, "secrets", "key.pem"), []byte("k"), 0o644)) - tool := NewFilesystemTool(wd, WithDenyList([]string{"secrets"})) + tool := New(wd, WithDenyList([]string{"secrets"})) // edit_file: must refuse to read the file in a denied directory. res, err := tool.handleEditFile(t.Context(), EditFileArgs{ @@ -336,18 +336,18 @@ func TestFilesystemTool_Instructions_MentionsRestrictions(t *testing.T) { wd := t.TempDir() // Default instructions: no restriction text. - plain := NewFilesystemTool(wd).Instructions() + plain := New(wd).Instructions() assert.NotContains(t, plain, "restricted") assert.NotContains(t, plain, "must not access") // With an allow-list: instructions mention the restriction. - allowed := NewFilesystemTool(wd, WithAllowList([]string{".", "~"})).Instructions() + allowed := New(wd, WithAllowList([]string{".", "~"})).Instructions() assert.Contains(t, allowed, "restricted") assert.Contains(t, allowed, ".") assert.Contains(t, allowed, "~") // With a deny-list: instructions mention the deny entries. - denied := NewFilesystemTool(wd, WithDenyList([]string{"~/.ssh"})).Instructions() + denied := New(wd, WithDenyList([]string{"~/.ssh"})).Instructions() assert.Contains(t, denied, "must not access") assert.Contains(t, denied, "~/.ssh") } @@ -403,7 +403,7 @@ func TestWithAllowList_RejectsUndefinedEnvVar(t *testing.T) { // fail-closed: reject all operations when list construction fails. os.Unsetenv("DEFINITELY_NOT_SET") wd := t.TempDir() - tool := NewFilesystemTool(wd, WithAllowList([]string{"$DEFINITELY_NOT_SET"})) + tool := New(wd, WithAllowList([]string{"$DEFINITELY_NOT_SET"})) // The allow-list construction failed, so the toolset is disabled // (fail-closed). All operations must be rejected. @@ -422,7 +422,7 @@ func TestWithAllowList_AcceptsDefinedEnvVar(t *testing.T) { allowed := t.TempDir() t.Setenv("ALLOWED_DIR", allowed) - tool := NewFilesystemTool(wd, WithAllowList([]string{"$ALLOWED_DIR"})) + tool := New(wd, WithAllowList([]string{"$ALLOWED_DIR"})) // Inside the env-var-resolved root. _, err := tool.resolveAndCheckPath(filepath.Join(allowed, "file.txt")) @@ -441,7 +441,7 @@ func TestDenyList_NonExistentPath(t *testing.T) { resetHomeDir(t, homeDir) wd := t.TempDir() - tool := NewFilesystemTool(wd, WithDenyList([]string{"~/.ssh"})) + tool := New(wd, WithDenyList([]string{"~/.ssh"})) // ~/.ssh does not exist yet — a write to a path inside it must be // rejected before the directory is even created. diff --git a/pkg/tools/builtin/filesystem/filesystem_test.go b/pkg/tools/builtin/filesystem/filesystem_test.go index 60c4b2ab3..9c2139641 100644 --- a/pkg/tools/builtin/filesystem/filesystem_test.go +++ b/pkg/tools/builtin/filesystem/filesystem_test.go @@ -37,7 +37,7 @@ func initGitRepo(t *testing.T, dir string) { func TestFilesystemTool_DisplayNames(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) all, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -51,7 +51,7 @@ func TestFilesystemTool_DisplayNames(t *testing.T) { func TestFilesystemTool_ResolvePath(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) // Test relative path within working directory resolvedPath := tool.resolvePath("subdir/file.txt") @@ -76,7 +76,7 @@ func TestFilesystemTool_ResolvePath_ExpandsTilde(t *testing.T) { homeDir := t.TempDir() resetHomeDir(t, homeDir) wd := t.TempDir() - tool := NewFilesystemTool(wd) + tool := New(wd) assert.Equal(t, homeDir, tool.resolvePath("~")) assert.Equal(t, filepath.Join(homeDir, "file.txt"), tool.resolvePath("~/file.txt")) @@ -95,7 +95,7 @@ func TestFilesystemTool_ReadFile_TildePath(t *testing.T) { homeDir := t.TempDir() resetHomeDir(t, homeDir) wd := t.TempDir() - tool := NewFilesystemTool(wd) + tool := New(wd) content := "hello from home" require.NoError(t, os.WriteFile(filepath.Join(homeDir, "note.txt"), []byte(content), 0o644)) @@ -109,7 +109,7 @@ func TestFilesystemTool_ReadFile_TildePath(t *testing.T) { func TestFilesystemTool_WriteFile(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) testFile := "test.txt" content := "Hello, World!" @@ -129,7 +129,7 @@ func TestFilesystemTool_WriteFile(t *testing.T) { func TestFilesystemTool_WriteFile_NestedDirectory(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) nestedFile := "a/b/c/test.txt" content := "Hello, nested world!" @@ -155,7 +155,7 @@ func TestFilesystemTool_WriteFile_NestedDirectory(t *testing.T) { func TestFilesystemTool_ReadFile(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) testFile := "test.txt" content := "Hello, World!" @@ -177,7 +177,7 @@ func TestFilesystemTool_ReadFile(t *testing.T) { func TestFilesystemTool_ReadImageFile(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) // Create a valid PNG file using Go's image library. pngData := createTestPNG(t, 10, 10) @@ -211,7 +211,7 @@ func TestFilesystemTool_ReadImageFile(t *testing.T) { func TestFilesystemTool_ReadMultipleFiles(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) file1 := "file1.txt" file2 := "file2.txt" @@ -241,7 +241,7 @@ func TestFilesystemTool_ReadMultipleFiles(t *testing.T) { func TestFilesystemTool_ListDirectory(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) testFile := "test.txt" testDir := "testdir" @@ -266,7 +266,7 @@ func TestFilesystemTool_ListDirectory(t *testing.T) { func TestFilesystemTool_EditFile(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) testFile := "test.txt" originalContent := "Hello World\nThis is a test\nGoodbye World" @@ -447,7 +447,7 @@ func TestParseEditFileArgs(t *testing.T) { func TestFilesystemTool_SearchFilesContent(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) file1Content := "This is a test file\nwith multiple lines\ncontaining test data" file2Content := "Another file\nwith different content\nno matching terms here" @@ -503,7 +503,7 @@ func main() { Cmd: "touch $file.formatted", }, } - tool := NewFilesystemTool(tmpDir, WithPostEditCommands(postEditConfigs)) + tool := New(tmpDir, WithPostEditCommands(postEditConfigs)) formattedFile := filepath.Join(tmpDir, testFile+".formatted") t.Run("write_file", func(t *testing.T) { @@ -620,7 +620,7 @@ func TestMatchExcludePattern(t *testing.T) { func TestFilesystemTool_OutputSchema(t *testing.T) { tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -640,7 +640,7 @@ func TestFilesystemTool_IgnoreVCS_Default(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(gitDir, "config"), []byte("git config"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte("findme"), 0o644)) - tool := NewFilesystemTool(tmpDir, WithIgnoreVCS(true)) + tool := New(tmpDir, WithIgnoreVCS(true)) result, err := tool.handleSearchFilesContent(t.Context(), SearchFilesContentArgs{ Path: ".", Query: "findme", @@ -659,7 +659,7 @@ func TestFilesystemTool_IgnoreVCS_Disabled(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(gitDir, "config"), []byte("findme"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte("findme"), 0o644)) - tool := NewFilesystemTool(tmpDir, WithIgnoreVCS(false)) + tool := New(tmpDir, WithIgnoreVCS(false)) result, err := tool.handleSearchFilesContent(t.Context(), SearchFilesContentArgs{ Path: ".", Query: "findme", @@ -694,7 +694,7 @@ temp_* require.NoError(t, os.Mkdir(filepath.Join(tmpDir, "build"), 0o755)) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "build", "output.js"), []byte("findme"), 0o644)) - tool := NewFilesystemTool(tmpDir, WithIgnoreVCS(true)) + tool := New(tmpDir, WithIgnoreVCS(true)) result, err := tool.handleSearchFilesContent(t.Context(), SearchFilesContentArgs{ Path: ".", Query: "findme", @@ -717,7 +717,7 @@ func TestFilesystemTool_SearchContent_WithGitignore(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "source.txt"), []byte("findme"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "debug.log"), []byte("findme"), 0o644)) - tool := NewFilesystemTool(tmpDir, WithIgnoreVCS(true)) + tool := New(tmpDir, WithIgnoreVCS(true)) result, err := tool.handleSearchFilesContent(t.Context(), SearchFilesContentArgs{ Path: ".", Query: "findme", @@ -737,7 +737,7 @@ func TestFilesystemTool_ListDirectory_IgnoresVCS(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("test"), 0o644)) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("test"), 0o644)) - tool := NewFilesystemTool(tmpDir, WithIgnoreVCS(true)) + tool := New(tmpDir, WithIgnoreVCS(true)) result, err := tool.handleListDirectory(t.Context(), ListDirectoryArgs{ Path: ".", }) @@ -765,7 +765,7 @@ func TestFilesystemTool_SubdirectoryGitignorePatterns(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(subDir, "sub.log"), []byte("findme"), 0o644)) // ignored by root .gitignore require.NoError(t, os.WriteFile(filepath.Join(subDir, "sub.tmp"), []byte("findme"), 0o644)) // ignored by subdir .gitignore - tool := NewFilesystemTool(tmpDir, WithIgnoreVCS(true)) + tool := New(tmpDir, WithIgnoreVCS(true)) result, err := tool.handleSearchFilesContent(t.Context(), SearchFilesContentArgs{ Path: ".", Query: "findme", @@ -791,7 +791,7 @@ func TestFilesystemTool_DirectoryTree_IgnoresVCS(t *testing.T) { require.NoError(t, os.Mkdir(srcDir, 0o755)) require.NoError(t, os.WriteFile(filepath.Join(srcDir, "main.go"), []byte("package main"), 0o644)) - tool := NewFilesystemTool(tmpDir, WithIgnoreVCS(true)) + tool := New(tmpDir, WithIgnoreVCS(true)) result, err := tool.handleDirectoryTree(t.Context(), DirectoryTreeArgs{ Path: ".", }) @@ -814,7 +814,7 @@ func TestFilesystemTool_DirectoryTree_IgnoresVCS(t *testing.T) { func TestFilesystemTool_EmptyWorkingDir(t *testing.T) { t.Parallel() - tool := NewFilesystemTool("") + tool := New("") // With empty working dir, relative paths are resolved relative to current directory resolvedPath := tool.resolvePath("test.txt") @@ -828,7 +828,7 @@ func TestFilesystemTool_EmptyWorkingDir(t *testing.T) { func TestFilesystemTool_CreateDirectory(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) result, err := tool.handleCreateDirectory(t.Context(), CreateDirectoryArgs{ Paths: []string{"newdir"}, @@ -841,7 +841,7 @@ func TestFilesystemTool_CreateDirectory(t *testing.T) { func TestFilesystemTool_CreateDirectory_Multiple(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) result, err := tool.handleCreateDirectory(t.Context(), CreateDirectoryArgs{ Paths: []string{"dir1", "dir2", "dir3"}, @@ -858,7 +858,7 @@ func TestFilesystemTool_CreateDirectory_Multiple(t *testing.T) { func TestFilesystemTool_CreateDirectory_Nested(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) result, err := tool.handleCreateDirectory(t.Context(), CreateDirectoryArgs{ Paths: []string{"a/b/c"}, @@ -871,7 +871,7 @@ func TestFilesystemTool_CreateDirectory_Nested(t *testing.T) { func TestFilesystemTool_CreateDirectory_AlreadyExists(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) require.NoError(t, os.Mkdir(filepath.Join(tmpDir, "existing"), 0o755)) @@ -885,7 +885,7 @@ func TestFilesystemTool_CreateDirectory_AlreadyExists(t *testing.T) { func TestFilesystemTool_RemoveDirectory(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) dirPath := filepath.Join(tmpDir, "toremove") require.NoError(t, os.Mkdir(dirPath, 0o755)) @@ -901,7 +901,7 @@ func TestFilesystemTool_RemoveDirectory(t *testing.T) { func TestFilesystemTool_RemoveDirectory_Multiple(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) dir1 := filepath.Join(tmpDir, "dir1") dir2 := filepath.Join(tmpDir, "dir2") @@ -921,7 +921,7 @@ func TestFilesystemTool_RemoveDirectory_Multiple(t *testing.T) { func TestFilesystemTool_RemoveDirectory_NotEmpty(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) dirPath := filepath.Join(tmpDir, "notempty") require.NoError(t, os.Mkdir(dirPath, 0o755)) @@ -939,7 +939,7 @@ func TestFilesystemTool_RemoveDirectory_NotEmpty(t *testing.T) { func TestFilesystemTool_RemoveDirectory_NotExists(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) result, err := tool.handleRemoveDirectory(t.Context(), RemoveDirectoryArgs{ Paths: []string{"nonexistent"}, @@ -952,7 +952,7 @@ func TestFilesystemTool_RemoveDirectory_NotExists(t *testing.T) { func TestFilesystemTool_RemoveDirectory_IsFile(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "file.txt"), []byte("content"), 0o644)) @@ -967,7 +967,7 @@ func TestFilesystemTool_RemoveDirectory_IsFile(t *testing.T) { func TestFilesystemTool_RemoveDirectory_MultipleStopsOnError(t *testing.T) { t.Parallel() tmpDir := t.TempDir() - tool := NewFilesystemTool(tmpDir) + tool := New(tmpDir) dir1 := filepath.Join(tmpDir, "dir1") dir3 := filepath.Join(tmpDir, "dir3") @@ -1038,7 +1038,7 @@ func TestFilesystemTool_RootedWriteRefusesSymlinkSwap(t *testing.T) { secret := filepath.Join(outside, "secret.txt") require.NoError(t, os.WriteFile(secret, []byte("untouched"), 0o600)) - tool := NewFilesystemTool(wd, WithAllowList([]string{"."})) + tool := New(wd, WithAllowList([]string{"."})) t.Cleanup(func() { _ = tool.Close() }) // Step 1: validate the path while the file is legitimate. This is @@ -1079,7 +1079,7 @@ func TestFilesystemTool_RootedReadRefusesSymlinkSwap(t *testing.T) { secret := filepath.Join(outside, "secret.txt") require.NoError(t, os.WriteFile(secret, []byte("CONFIDENTIAL"), 0o600)) - tool := NewFilesystemTool(wd, WithAllowList([]string{"."})) + tool := New(wd, WithAllowList([]string{"."})) t.Cleanup(func() { _ = tool.Close() }) resolved, err := tool.resolveAndCheckPath("report.txt") @@ -1109,7 +1109,7 @@ func TestFilesystemTool_StaticCheckRejectsExistingSymlink(t *testing.T) { require.NoError(t, os.WriteFile(secret, []byte("CONFIDENTIAL"), 0o600)) require.NoError(t, os.Symlink(secret, filepath.Join(wd, "report.txt"))) - tool := NewFilesystemTool(wd, WithAllowList([]string{"."})) + tool := New(wd, WithAllowList([]string{"."})) t.Cleanup(func() { _ = tool.Close() }) _, err := tool.resolveAndCheckPath("report.txt") @@ -1133,7 +1133,7 @@ func TestFilesystemTool_RootedListDirRefusesSymlinkSwap(t *testing.T) { subdir := filepath.Join(wd, "sub") require.NoError(t, os.MkdirAll(subdir, 0o755)) - tool := NewFilesystemTool(wd, WithAllowList([]string{"."})) + tool := New(wd, WithAllowList([]string{"."})) t.Cleanup(func() { _ = tool.Close() }) resolved, err := tool.resolveAndCheckPath("sub") diff --git a/pkg/tools/builtin/handoff/handoff.go b/pkg/tools/builtin/handoff/handoff.go index 1cceb239e..1d8aa2027 100644 --- a/pkg/tools/builtin/handoff/handoff.go +++ b/pkg/tools/builtin/handoff/handoff.go @@ -8,19 +8,19 @@ import ( const ToolNameHandoff = "handoff" -type Tool struct{} +type ToolSet struct{} -var _ tools.ToolSet = (*Tool)(nil) +var _ tools.ToolSet = (*ToolSet)(nil) type Args struct { Agent string `json:"agent" jsonschema:"The name of the agent to hand off the conversation to."` } -func NewHandoffTool() *Tool { - return &Tool{} +func New() *ToolSet { + return &ToolSet{} } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameHandoff, diff --git a/pkg/tools/builtin/lsp/lsp.go b/pkg/tools/builtin/lsp/lsp.go index 39b2e9606..c20d90f3c 100644 --- a/pkg/tools/builtin/lsp/lsp.go +++ b/pkg/tools/builtin/lsp/lsp.go @@ -20,8 +20,13 @@ import ( "time" "github.com/docker/docker-agent/pkg/concurrent" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/toolinstall" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/lifecycle" + "github.com/docker/docker-agent/pkg/tools/workingdir" ) const ( @@ -42,18 +47,18 @@ const ( ToolNameLSPInlayHints = "lsp_inlay_hints" ) -// Tool implements tools.ToolSet for connecting to any LSP server. +// ToolSet implements tools.ToolSet for connecting to any LSP server. // It provides stateless code intelligence tools that automatically manage // the LSP server lifecycle and document state. -type Tool struct { +type ToolSet struct { handler *lspHandler } // Verify interface compliance var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Startable = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Startable = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) type lspHandler struct { @@ -339,13 +344,41 @@ type lspInlayHint struct { PaddingRight bool `json:"paddingRight,omitempty"` } -// NewLSPTool creates a new LSP tool that connects to an LSP server. +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + resolvedCommand, err := toolinstall.EnsureCommand(ctx, toolset.Command, toolset.Version) + if err != nil { + return nil, fmt.Errorf("resolving command %q: %w", toolset.Command, err) + } + + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), runConfig.EnvProvider()) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + env = append(env, os.Environ()...) + env = toolinstall.PrependBinDirToEnv(env) + + cwd := workingdir.Resolve(toolset.WorkingDir, runConfig.WorkingDir) + if toolset.WorkingDir != "" { + if err := workingdir.CheckDirExists(cwd, "lsp"); err != nil { + return nil, err + } + } + + tool := New(resolvedCommand, toolset.Args, env, cwd, lifecycle.PolicyFromConfig(toolset.Name, toolset.Lifecycle)) + if len(toolset.FileTypes) > 0 { + tool.SetFileTypes(toolset.FileTypes) + } + return tool, nil +} + +// New creates a new LSP toolset that connects to an LSP server. // // The optional policy lets callers tune restart/backoff behaviour. When // the zero value is passed the supervisor uses its built-in defaults // (RestartOnFailure, 5 attempts, 1s..32s backoff). Internal callbacks // (OnDisconnect, Logger) are always set by the constructor. -func NewLSPTool(command string, args, env []string, workingDir string, policy ...lifecycle.Policy) *Tool { +func New(command string, args, env []string, workingDir string, policy ...lifecycle.Policy) *ToolSet { h := &lspHandler{ command: command, args: args, @@ -367,36 +400,36 @@ func NewLSPTool(command string, args, env []string, workingDir string, policy .. h.diagnosticsMu.Unlock() } h.supervisor = lifecycle.New("lsp/"+command, &lspConnector{h: h}, base) - return &Tool{handler: h} + return &ToolSet{handler: h} } // SetFileTypes sets the file types (extensions) that this LSP server handles. -func (t *Tool) SetFileTypes(fileTypes []string) { +func (t *ToolSet) SetFileTypes(fileTypes []string) { t.handler.fileTypes = fileTypes } // WorkingDir returns the working directory of the LSP server process. // This is intended for testing only. -func (t *Tool) WorkingDir() string { +func (t *ToolSet) WorkingDir() string { return t.handler.workingDir } // HandlesFile checks if this LSP handles the given file based on its extension. -func (t *Tool) HandlesFile(path string) bool { +func (t *ToolSet) HandlesFile(path string) bool { return t.handler.handlesFile(path) } -func (t *Tool) Start(ctx context.Context) error { +func (t *ToolSet) Start(ctx context.Context) error { return t.handler.supervisor.Start(ctx) } -func (t *Tool) Stop(ctx context.Context) error { +func (t *ToolSet) Stop(ctx context.Context) error { return t.handler.supervisor.Stop(ctx) } // State returns a snapshot of the underlying supervisor's lifecycle state, // suitable for the /tools dialog and lifecycle log messages. -func (t *Tool) State() lifecycle.StateInfo { +func (t *ToolSet) State() lifecycle.StateInfo { return t.handler.supervisor.State() } @@ -404,7 +437,7 @@ func (t *Tool) State() lifecycle.StateInfo { // Stopped supervisors are recovered via Start; otherwise the current // session is dropped and we wait for the supervisor to reconnect. // Blocks up to 35s (matching the MCP toolset). -func (t *Tool) Restart(ctx context.Context) error { +func (t *ToolSet) Restart(ctx context.Context) error { if t.handler.supervisor.State().State.IsTerminal() { return t.handler.supervisor.Start(ctx) } @@ -414,17 +447,17 @@ func (t *Tool) Restart(ctx context.Context) error { // Kind returns the user-facing classification of this toolset. Used by // status surfaces such as the /tools dialog so they can label the // toolset without leaking Go type names. -func (t *Tool) Kind() string { return "LSP" } +func (t *ToolSet) Kind() string { return "LSP" } // Name returns the basename of the configured command ("gopls", // "rust-analyzer", …). It's the most useful identifier in the absence // of a YAML name: field on LSP toolsets, and lets the /tools dialog // distinguish multiple language servers in the same agent. -func (t *Tool) Name() string { +func (t *ToolSet) Name() string { return filepath.Base(t.handler.command) } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return `# LSP Code Intelligence Tools Stateless code intelligence tools via Language Server Protocol. Just provide file path and position. @@ -473,7 +506,7 @@ func lspTool(name, title, description string, readOnly bool, params any, handler } } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { h := t.handler all := allLSPTools(h) @@ -499,7 +532,7 @@ func (h *lspHandler) snapshotCapabilities() *lspServerCapabilities { // supervisor reaches Ready and the server's capability matrix becomes // available. The runtime uses this to re-query Tools() and pick up the // capability-filtered list. -func (t *Tool) SetToolsChangedHandler(handler func()) { +func (t *ToolSet) SetToolsChangedHandler(handler func()) { t.handler.mu.Lock() t.handler.toolsChangedHandler = handler t.handler.mu.Unlock() diff --git a/pkg/tools/builtin/lsp/lsp_capabilities_test.go b/pkg/tools/builtin/lsp/lsp_capabilities_test.go index 08aeece53..95e3ef008 100644 --- a/pkg/tools/builtin/lsp/lsp_capabilities_test.go +++ b/pkg/tools/builtin/lsp/lsp_capabilities_test.go @@ -111,7 +111,7 @@ func TestIsProviderEnabled(t *testing.T) { func TestSetToolsChangedHandler_RegisterAndFire(t *testing.T) { t.Parallel() - tool := NewLSPTool("nope", nil, nil, t.TempDir()) + tool := New("nope", nil, nil, t.TempDir()) called := 0 tool.SetToolsChangedHandler(func() { called++ }) diff --git a/pkg/tools/builtin/lsp/lsp_lifecycle_test.go b/pkg/tools/builtin/lsp/lsp_lifecycle_test.go index 34e2be420..9ae559772 100644 --- a/pkg/tools/builtin/lsp/lsp_lifecycle_test.go +++ b/pkg/tools/builtin/lsp/lsp_lifecycle_test.go @@ -17,7 +17,7 @@ func TestLSPTool_StartFailureWhenServerMissing(t *testing.T) { t.Parallel() // Use a binary path that surely does not exist anywhere on PATH. - tool := NewLSPTool("docker-agent-lsp-does-not-exist", nil, nil, t.TempDir()) + tool := New("docker-agent-lsp-does-not-exist", nil, nil, t.TempDir()) err := tool.Start(t.Context()) require.Error(t, err) assert.ErrorIs(t, err, lifecycle.ErrServerUnavailable) @@ -28,7 +28,7 @@ func TestLSPTool_StartFailureWhenServerMissing(t *testing.T) { func TestLSPTool_StopBeforeStart(t *testing.T) { t.Parallel() - tool := NewLSPTool("docker-agent-lsp-does-not-exist", nil, nil, t.TempDir()) + tool := New("docker-agent-lsp-does-not-exist", nil, nil, t.TempDir()) require.NoError(t, tool.Stop(t.Context())) require.NoError(t, tool.Stop(t.Context())) } @@ -40,7 +40,7 @@ func TestLSPTool_StopBeforeStart(t *testing.T) { func TestLSPTool_SupervisorRetryAfterFailure(t *testing.T) { t.Parallel() - tool := NewLSPTool("docker-agent-lsp-does-not-exist", nil, nil, t.TempDir()) + tool := New("docker-agent-lsp-does-not-exist", nil, nil, t.TempDir()) // First attempt: Start fails because the binary is missing. err := tool.Start(t.Context()) diff --git a/pkg/tools/builtin/lsp/lsp_multiplexer.go b/pkg/tools/builtin/lsp/lsp_multiplexer.go index deb742847..bd335ac49 100644 --- a/pkg/tools/builtin/lsp/lsp_multiplexer.go +++ b/pkg/tools/builtin/lsp/lsp_multiplexer.go @@ -22,13 +22,13 @@ type Multiplexer struct { // optionally-wrapped ToolSet (used for tool enumeration, so that per-toolset // config like tool filters, instructions, or toon wrappers are respected). type Backend struct { - LSP *Tool + LSP *ToolSet Toolset tools.ToolSet } // lspRouteTarget pairs a backend with the tool handler it produced for a given tool name. type lspRouteTarget struct { - lsp *Tool + lsp *ToolSet handler tools.ToolHandler } diff --git a/pkg/tools/builtin/lsp/lsp_multiplexer_test.go b/pkg/tools/builtin/lsp/lsp_multiplexer_test.go index 3d4603b0c..096dec6e3 100644 --- a/pkg/tools/builtin/lsp/lsp_multiplexer_test.go +++ b/pkg/tools/builtin/lsp/lsp_multiplexer_test.go @@ -12,11 +12,11 @@ import ( ) // newTestMultiplexer creates a multiplexer with a Go and Python backend. -func newTestMultiplexer() (*Multiplexer, *Tool) { - goTool := NewLSPTool("gopls", nil, nil, "/tmp") +func newTestMultiplexer() (*Multiplexer, *ToolSet) { + goTool := New("gopls", nil, nil, "/tmp") goTool.SetFileTypes([]string{".go", ".mod"}) - pyTool := NewLSPTool("pyright", nil, nil, "/tmp") + pyTool := New("pyright", nil, nil, "/tmp") pyTool.SetFileTypes([]string{".py"}) mux := NewLSPMultiplexer([]Backend{ diff --git a/pkg/tools/builtin/lsp/lsp_test.go b/pkg/tools/builtin/lsp/lsp_test.go index 457d54197..8b1607d71 100644 --- a/pkg/tools/builtin/lsp/lsp_test.go +++ b/pkg/tools/builtin/lsp/lsp_test.go @@ -9,10 +9,10 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewLSPTool(t *testing.T) { +func TestNew(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", []string{}, nil, "/tmp") + tool := New("gopls", []string{}, nil, "/tmp") require.NotNil(t, tool) require.NotNil(t, tool.handler) } @@ -20,7 +20,7 @@ func TestNewLSPTool(t *testing.T) { func TestLSPTool_Tools(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") tools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -59,7 +59,7 @@ func TestLSPTool_Tools(t *testing.T) { func TestLSPTool_ToolDescriptions(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") tools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -76,7 +76,7 @@ func TestLSPTool_ToolDescriptions(t *testing.T) { func TestLSPTool_Instructions(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") instructions := tool.Instructions() // Should mention the tools are stateless @@ -250,7 +250,7 @@ func TestFormatHoverContents(t *testing.T) { func TestLSPHandler_NotInitialized_AutoInitializes(t *testing.T) { t.Parallel() - tool := NewLSPTool("nonexistent-lsp-server", nil, nil, "/tmp") + tool := New("nonexistent-lsp-server", nil, nil, "/tmp") // Test that operations attempt auto-initialization // (will fail because the server doesn't exist, but should try) @@ -270,7 +270,7 @@ func TestLSPHandler_NotInitialized_AutoInitializes(t *testing.T) { func TestLSPHandler_GetDiagnostics_NoDiagnostics(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") // Mark as initialized to test the diagnostics retrieval path tool.handler.initialized.Store(true) // Pretend we have a running server by setting a non-nil cmd @@ -291,7 +291,7 @@ func TestLSPHandler_GetDiagnostics_NoDiagnostics(t *testing.T) { func TestLSPHandler_GetDiagnostics_WithDiagnostics(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") // Mark as initialized to test the diagnostics retrieval path tool.handler.initialized.Store(true) // Pretend we have a running server @@ -318,7 +318,7 @@ func TestLSPHandler_GetDiagnostics_WithDiagnostics(t *testing.T) { func TestProcessNotification_Diagnostics(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") // Create a diagnostic notification notification := map[string]any{ @@ -356,7 +356,7 @@ func TestProcessNotification_Diagnostics(t *testing.T) { func TestLSPHandler_Stop_NotStarted(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") ctx := t.Context() // Should not error when stopping a non-started server @@ -410,13 +410,13 @@ func TestLSPTool_HandlesFile(t *testing.T) { t.Parallel() // Without file type filter (handles all) - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") assert.True(t, tool.HandlesFile("main.go")) assert.True(t, tool.HandlesFile("app.py")) assert.True(t, tool.HandlesFile("anything.txt")) // With file type filter - toolFiltered := NewLSPTool("gopls", nil, nil, "/tmp") + toolFiltered := New("gopls", nil, nil, "/tmp") toolFiltered.SetFileTypes([]string{".go", ".mod"}) assert.True(t, toolFiltered.HandlesFile("main.go")) assert.True(t, toolFiltered.HandlesFile("go.mod")) @@ -424,7 +424,7 @@ func TestLSPTool_HandlesFile(t *testing.T) { assert.False(t, toolFiltered.HandlesFile("index.js")) // Without leading dot in filter - toolNoDot := NewLSPTool("gopls", nil, nil, "/tmp") + toolNoDot := New("gopls", nil, nil, "/tmp") toolNoDot.SetFileTypes([]string{"go", "py"}) assert.True(t, toolNoDot.HandlesFile("main.go")) assert.True(t, toolNoDot.HandlesFile("app.py")) @@ -442,7 +442,7 @@ func TestPathToURI(t *testing.T) { func TestLSPHandler_IsFileOpen(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") // Initially no files are open assert.False(t, tool.handler.isFileOpen("file:///test.go")) @@ -459,7 +459,7 @@ func TestLSPHandler_IsFileOpen(t *testing.T) { func TestLSPHandler_DiagnosticsVersion(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", nil, nil, "/tmp") + tool := New("gopls", nil, nil, "/tmp") // Initial version should be 0 assert.Equal(t, int64(0), tool.handler.diagnosticsVersion.Load()) @@ -751,7 +751,7 @@ func TestCapabilityStatus(t *testing.T) { func TestLSPHandler_Workspace(t *testing.T) { t.Parallel() - tool := NewLSPTool("gopls", []string{"-remote=auto"}, nil, "/tmp/project") + tool := New("gopls", []string{"-remote=auto"}, nil, "/tmp/project") tool.SetFileTypes([]string{".go", ".mod"}) // Mark as initialized and set server info/capabilities diff --git a/pkg/tools/builtin/memory/memory.go b/pkg/tools/builtin/memory/memory.go index 72b61a6ff..1e200346e 100644 --- a/pkg/tools/builtin/memory/memory.go +++ b/pkg/tools/builtin/memory/memory.go @@ -4,11 +4,18 @@ import ( "context" "encoding/json" "fmt" + "os" + "path/filepath" "strconv" "time" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/memory/database" + "github.com/docker/docker-agent/pkg/memory/database/sqlite" + "github.com/docker/docker-agent/pkg/paths" "github.com/docker/docker-agent/pkg/tools" + "github.com/docker/docker-agent/pkg/tools/toolsetpath" ) const ( @@ -27,35 +34,64 @@ type DB interface { UpdateMemory(ctx context.Context, memory database.UserMemory) error } -type Tool struct { +type ToolSet struct { db DB path string } // Verify interface compliance var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Describer = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Describer = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) -func NewMemoryTool(manager DB) *Tool { - return &Tool{ +// CreateToolSet is used by the tools registry. +func CreateToolSet(toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, configName string) (tools.ToolSet, error) { + var validatedMemoryPath string + + if toolset.Path != "" { + var err error + validatedMemoryPath, err = toolsetpath.Resolve(toolset.Path, parentDir, runConfig) + if err != nil { + return nil, fmt.Errorf("invalid memory database path: %w", err) + } + } else { + if configName == "" { + configName = "default" + } + validatedMemoryPath = filepath.Join(paths.GetDataDir(), "memory", configName, "memory.db") + } + + if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil { + return nil, fmt.Errorf("failed to create memory database directory: %w", err) + } + + db, err := sqlite.NewMemoryDatabase(validatedMemoryPath) + if err != nil { + return nil, fmt.Errorf("failed to create memory database: %w", err) + } + + return NewWithPath(db, validatedMemoryPath), nil +} + +func New(manager DB) *ToolSet { + return &ToolSet{ db: manager, } } -// NewMemoryToolWithPath creates a Tool and records the database path for +// NewWithPath creates a ToolSet and records the database path for // user-visible identification in warnings and error messages. -func NewMemoryToolWithPath(manager DB, dbPath string) *Tool { - return &Tool{ +func NewWithPath(manager DB, dbPath string) *ToolSet { + return &ToolSet{ db: manager, path: dbPath, } } // Describe returns a short, user-visible description of this toolset instance. -func (t *Tool) Describe() string { +func (t *ToolSet) Describe() string { if t.path != "" { return "memory(path=" + t.path + ")" } @@ -82,7 +118,7 @@ type UpdateMemoryArgs struct { Category string `json:"category,omitempty" jsonschema:"Optional new category for the memory"` } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return `## Memory Tools Check stored memories for relevant context before acting. Store useful information silently — never mention using this tool. @@ -93,7 +129,7 @@ Check stored memories for relevant context before acting. Store useful informati - Organize with categories: "preference", "fact", "project", "decision"` } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameAddMemory, @@ -154,7 +190,7 @@ func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { }, nil } -func (t *Tool) handleAddMemory(ctx context.Context, args AddMemoryArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleAddMemory(ctx context.Context, args AddMemoryArgs) (*tools.ToolCallResult, error) { memory := database.UserMemory{ ID: strconv.FormatInt(time.Now().UnixNano(), 10), CreatedAt: time.Now().Format(time.RFC3339), @@ -169,7 +205,7 @@ func (t *Tool) handleAddMemory(ctx context.Context, args AddMemoryArgs) (*tools. return tools.ResultSuccess("Memory added successfully with ID: " + memory.ID), nil } -func (t *Tool) handleGetMemories(ctx context.Context, _ map[string]any) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleGetMemories(ctx context.Context, _ map[string]any) (*tools.ToolCallResult, error) { memories, err := t.db.GetMemories(ctx) if err != nil { return nil, fmt.Errorf("failed to get memories: %w", err) @@ -183,7 +219,7 @@ func (t *Tool) handleGetMemories(ctx context.Context, _ map[string]any) (*tools. return tools.ResultSuccess(string(result)), nil } -func (t *Tool) handleDeleteMemory(ctx context.Context, args DeleteMemoryArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleDeleteMemory(ctx context.Context, args DeleteMemoryArgs) (*tools.ToolCallResult, error) { memory := database.UserMemory{ ID: args.ID, } @@ -195,7 +231,7 @@ func (t *Tool) handleDeleteMemory(ctx context.Context, args DeleteMemoryArgs) (* return tools.ResultSuccess(fmt.Sprintf("Memory with ID %s deleted successfully", args.ID)), nil } -func (t *Tool) handleSearchMemories(ctx context.Context, args SearchMemoriesArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleSearchMemories(ctx context.Context, args SearchMemoriesArgs) (*tools.ToolCallResult, error) { memories, err := t.db.SearchMemories(ctx, args.Query, args.Category) if err != nil { return nil, fmt.Errorf("failed to search memories: %w", err) @@ -209,7 +245,7 @@ func (t *Tool) handleSearchMemories(ctx context.Context, args SearchMemoriesArgs return tools.ResultSuccess(string(result)), nil } -func (t *Tool) handleUpdateMemory(ctx context.Context, args UpdateMemoryArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleUpdateMemory(ctx context.Context, args UpdateMemoryArgs) (*tools.ToolCallResult, error) { memory := database.UserMemory{ ID: args.ID, Memory: args.Memory, diff --git a/pkg/tools/builtin/memory/memory_test.go b/pkg/tools/builtin/memory/memory_test.go index cc2d6ba16..3dea76537 100644 --- a/pkg/tools/builtin/memory/memory_test.go +++ b/pkg/tools/builtin/memory/memory_test.go @@ -46,7 +46,7 @@ func (m *MockDB) UpdateMemory(ctx context.Context, memory database.UserMemory) e func TestMemoryTool_Instructions(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) instructions := tool.Instructions() assert.Contains(t, instructions, "Memory Tools") @@ -57,7 +57,7 @@ func TestMemoryTool_Instructions(t *testing.T) { func TestMemoryTool_DisplayNames(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) all, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -70,7 +70,7 @@ func TestMemoryTool_DisplayNames(t *testing.T) { func TestMemoryTool_HandleAddMemory(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) manager.On("AddMemory", mock.Anything, mock.MatchedBy(func(memory database.UserMemory) bool { return memory.Memory == "test memory" @@ -86,7 +86,7 @@ func TestMemoryTool_HandleAddMemory(t *testing.T) { func TestMemoryTool_HandleAddMemoryWithCategory(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) manager.On("AddMemory", mock.Anything, mock.MatchedBy(func(memory database.UserMemory) bool { return memory.Memory == "prefers dark mode" && memory.Category == "preference" @@ -103,7 +103,7 @@ func TestMemoryTool_HandleAddMemoryWithCategory(t *testing.T) { func TestMemoryTool_HandleGetMemories(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) memories := []database.UserMemory{ { @@ -133,7 +133,7 @@ func TestMemoryTool_HandleGetMemories(t *testing.T) { func TestMemoryTool_HandleDeleteMemory(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) manager.On("DeleteMemory", mock.Anything, mock.MatchedBy(func(memory database.UserMemory) bool { return memory.ID == "1" @@ -150,7 +150,7 @@ func TestMemoryTool_HandleDeleteMemory(t *testing.T) { func TestMemoryTool_HandleSearchMemories(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) memories := []database.UserMemory{ { @@ -179,7 +179,7 @@ func TestMemoryTool_HandleSearchMemories(t *testing.T) { func TestMemoryTool_HandleUpdateMemory(t *testing.T) { manager := new(MockDB) - tool := NewMemoryTool(manager) + tool := New(manager) manager.On("UpdateMemory", mock.Anything, mock.MatchedBy(func(memory database.UserMemory) bool { return memory.ID == "42" && memory.Memory == "updated content" && memory.Category == "fact" @@ -196,7 +196,7 @@ func TestMemoryTool_HandleUpdateMemory(t *testing.T) { } func TestMemoryTool_ToolCount(t *testing.T) { - tool := NewMemoryTool(nil) + tool := New(nil) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -204,7 +204,7 @@ func TestMemoryTool_ToolCount(t *testing.T) { } func TestMemoryTool_OutputSchema(t *testing.T) { - tool := NewMemoryTool(nil) + tool := New(nil) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -216,7 +216,7 @@ func TestMemoryTool_OutputSchema(t *testing.T) { } func TestMemoryTool_ParametersAreObjects(t *testing.T) { - tool := NewMemoryTool(nil) + tool := New(nil) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) diff --git a/pkg/tools/builtin/modelpicker/model_picker_test.go b/pkg/tools/builtin/modelpicker/model_picker_test.go index 5dd2ee4da..1fc2ad715 100644 --- a/pkg/tools/builtin/modelpicker/model_picker_test.go +++ b/pkg/tools/builtin/modelpicker/model_picker_test.go @@ -10,21 +10,21 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -func TestNewModelPickerTool(t *testing.T) { - tool := NewModelPickerTool([]string{"openai/gpt-4o", "anthropic/claude-sonnet-4-0"}) +func TestNew(t *testing.T) { + tool := New([]string{"openai/gpt-4o", "anthropic/claude-sonnet-4-0"}) assert.NotNil(t, tool) } func TestModelPickerTool_AllowedModels(t *testing.T) { models := []string{"openai/gpt-4o", "anthropic/claude-sonnet-4-0", "my_fast_model"} - tool := NewModelPickerTool(models) + tool := New(models) assert.Equal(t, models, tool.AllowedModels()) } func TestModelPickerTool_Tools(t *testing.T) { models := []string{"openai/gpt-4o", "anthropic/claude-sonnet-4-0"} - tool := NewModelPickerTool(models) + tool := New(models) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -67,7 +67,7 @@ func TestModelPickerTool_Tools(t *testing.T) { func TestModelPickerTool_ToolsDescriptionListsModels(t *testing.T) { models := []string{"fast_model", "smart_model", "openai/gpt-4o"} - tool := NewModelPickerTool(models) + tool := New(models) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -80,7 +80,7 @@ func TestModelPickerTool_ToolsDescriptionListsModels(t *testing.T) { } func TestModelPickerTool_DisplayNames(t *testing.T) { - tool := NewModelPickerTool([]string{"openai/gpt-4o"}) + tool := New([]string{"openai/gpt-4o"}) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -93,7 +93,7 @@ func TestModelPickerTool_DisplayNames(t *testing.T) { } func TestModelPickerTool_ParametersAreObjects(t *testing.T) { - tool := NewModelPickerTool([]string{"openai/gpt-4o"}) + tool := New([]string{"openai/gpt-4o"}) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -109,7 +109,7 @@ func TestModelPickerTool_ParametersAreObjects(t *testing.T) { } func TestModelPickerTool_ReadOnlyHint(t *testing.T) { - tool := NewModelPickerTool([]string{"openai/gpt-4o"}) + tool := New([]string{"openai/gpt-4o"}) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -121,7 +121,7 @@ func TestModelPickerTool_ReadOnlyHint(t *testing.T) { } func TestModelPickerTool_NotStartable(t *testing.T) { - tool := NewModelPickerTool([]string{"openai/gpt-4o"}) + tool := New([]string{"openai/gpt-4o"}) _, ok := any(tool).(tools.Startable) assert.False(t, ok, "Tool should not implement Startable") @@ -129,7 +129,7 @@ func TestModelPickerTool_NotStartable(t *testing.T) { func TestModelPickerTool_Instructions(t *testing.T) { models := []string{"openai/gpt-4o", "anthropic/claude-sonnet-4-0"} - tool := NewModelPickerTool(models) + tool := New(models) _, ok := any(tool).(tools.Instructable) assert.True(t, ok, "Tool should implement Instructable") @@ -144,7 +144,7 @@ func TestModelPickerTool_Instructions(t *testing.T) { } func TestModelPickerTool_SingleModel(t *testing.T) { - tool := NewModelPickerTool([]string{"openai/gpt-4o"}) + tool := New([]string{"openai/gpt-4o"}) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -159,7 +159,7 @@ func TestModelPickerTool_ManyModels(t *testing.T) { "google/gemini-2.0-flash", "my_custom_model", } - tool := NewModelPickerTool(models) + tool := New(models) assert.Equal(t, models, tool.AllowedModels()) diff --git a/pkg/tools/builtin/modelpicker/modelpicker.go b/pkg/tools/builtin/modelpicker/modelpicker.go index 95e15dfcd..d3dc9e46c 100644 --- a/pkg/tools/builtin/modelpicker/modelpicker.go +++ b/pkg/tools/builtin/modelpicker/modelpicker.go @@ -2,9 +2,11 @@ package modelpicker import ( "context" + "errors" "fmt" "strings" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/tools" ) @@ -13,15 +15,23 @@ const ( ToolNameRevertModel = "revert_model" ) -// Tool provides tools for dynamically switching the agent's model mid-conversation. -type Tool struct { +// CreateToolSet is used by the tools registry. +func CreateToolSet(toolset latest.Toolset) (tools.ToolSet, error) { + if len(toolset.Models) == 0 { + return nil, errors.New("model_picker toolset requires at least one model") + } + return New(toolset.Models), nil +} + +// ToolSet provides tools for dynamically switching the agent's model mid-conversation. +type ToolSet struct { models []string // list of available model references } // Verify interface compliance var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) // ChangeModelArgs are the arguments for the change_model tool. @@ -29,13 +39,13 @@ type ChangeModelArgs struct { Model string `json:"model" jsonschema:"The model to switch to. Must be one of the available models."` } -// NewModelPickerTool creates a new Tool with the given list of allowed models. -func NewModelPickerTool(models []string) *Tool { - return &Tool{models: models} +// New creates a new ToolSet with the given list of allowed models. +func New(models []string) *ToolSet { + return &ToolSet{models: models} } // Instructions returns guidance for the LLM on when and how to use the model picker tools. -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return "## Model Switching\n\n" + "Available models: " + strings.Join(t.models, ", ") + ".\n\n" + "Use `" + ToolNameChangeModel + "` to switch to a model better suited for the current task " + @@ -44,12 +54,12 @@ func (t *Tool) Instructions() string { } // AllowedModels returns the list of models this tool allows switching to. -func (t *Tool) AllowedModels() []string { +func (t *ToolSet) AllowedModels() []string { return t.models } // Tools returns the change_model and revert_model tool definitions. -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameChangeModel, diff --git a/pkg/tools/builtin/openapi/openapi.go b/pkg/tools/builtin/openapi/openapi.go index e4863c11e..26d4ed14d 100644 --- a/pkg/tools/builtin/openapi/openapi.go +++ b/pkg/tools/builtin/openapi/openapi.go @@ -18,7 +18,10 @@ import ( v3 "github.com/pb33f/libopenapi/datamodel/high/v3" "go.yaml.in/yaml/v4" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/httpclient" + "github.com/docker/docker-agent/pkg/js" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/upstream" "github.com/docker/docker-agent/pkg/useragent" @@ -26,8 +29,18 @@ import ( const httpTimeout = 30 * time.Second -// Tool generates HTTP tools from an OpenAPI specification. -type Tool struct { +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + expander := js.NewJsExpander(runConfig.EnvProvider()) + + specURL := expander.Expand(ctx, toolset.URL, nil) + headers := expander.ExpandMap(ctx, toolset.Headers) + + return New(specURL, headers), nil +} + +// ToolSet generates HTTP tools from an OpenAPI specification. +type ToolSet struct { specURL string headers map[string]string @@ -40,20 +53,20 @@ type Tool struct { // Verify interface compliance. var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) -// NewOpenAPITool creates a new OpenAPI toolset from the given spec URL. -func NewOpenAPITool(specURL string, headers map[string]string) *Tool { - return &Tool{ +// New creates a new OpenAPI toolset from the given spec URL. +func New(specURL string, headers map[string]string) *ToolSet { + return &ToolSet{ specURL: specURL, headers: headers, } } // Instructions returns usage instructions for the OpenAPI toolset. -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return fmt.Sprintf(`## OpenAPI tools These tools were generated from the OpenAPI specification at %s. @@ -61,7 +74,7 @@ Each tool corresponds to an API endpoint. Use the tool parameters as described.` } // Tools fetches and parses the OpenAPI specification, returning a tool for each operation. -func (t *Tool) Tools(ctx context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(ctx context.Context) ([]tools.Tool, error) { spec, err := t.fetchSpec(ctx) if err != nil { return nil, fmt.Errorf("failed to fetch OpenAPI spec from %s: %w", t.specURL, err) @@ -71,7 +84,7 @@ func (t *Tool) Tools(ctx context.Context) ([]tools.Tool, error) { } // fetchSpec retrieves and parses the OpenAPI specification from the configured URL. -func (t *Tool) fetchSpec(ctx context.Context) (*v3.Document, error) { +func (t *ToolSet) fetchSpec(ctx context.Context) (*v3.Document, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, t.specURL, http.NoBody) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) @@ -121,7 +134,7 @@ func (t *Tool) fetchSpec(ctx context.Context) (*v3.Document, error) { } // buildTools converts an OpenAPI spec into a list of tools. -func (t *Tool) buildTools(spec *v3.Document) ([]tools.Tool, error) { +func (t *ToolSet) buildTools(spec *v3.Document) ([]tools.Tool, error) { baseURL, err := t.resolveBaseURL(spec) if err != nil { return nil, err @@ -162,7 +175,7 @@ func pathOperations(item *v3.PathItem) map[string]*v3.Operation { } // resolveBaseURL determines the base URL for API requests. -func (t *Tool) resolveBaseURL(spec *v3.Document) (string, error) { +func (t *ToolSet) resolveBaseURL(spec *v3.Document) (string, error) { if len(spec.Servers) > 0 && spec.Servers[0].URL != "" { serverURL := spec.Servers[0].URL @@ -194,7 +207,7 @@ func (t *Tool) resolveBaseURL(spec *v3.Document) (string, error) { } // operationToTool converts a single OpenAPI operation to a tool. -func (t *Tool) operationToTool(baseURL, path, method string, op *v3.Operation) tools.Tool { +func (t *ToolSet) operationToTool(baseURL, path, method string, op *v3.Operation) tools.Tool { name := operationToolName(path, method, op) desc := operationDescription(path, method, op) schema := operationSchema(op) diff --git a/pkg/tools/builtin/openapi/openapi_test.go b/pkg/tools/builtin/openapi/openapi_test.go index 6e7402760..d56b02e8a 100644 --- a/pkg/tools/builtin/openapi/openapi_test.go +++ b/pkg/tools/builtin/openapi/openapi_test.go @@ -19,9 +19,9 @@ import ( // dial-time protection so tests can talk to httptest.NewServer (which // binds to 127.0.0.1). It is defined in a *_test.go file so it is not // compiled into release binaries. Production callers must use -// [NewOpenAPITool]. -func newOpenAPIToolForTest(specURL string, headers map[string]string) *Tool { - t := NewOpenAPITool(specURL, headers) +// [New]. +func newOpenAPIToolForTest(specURL string, headers map[string]string) *ToolSet { + t := New(specURL, headers) t.unsafe = true return t } @@ -386,7 +386,7 @@ func TestOpenAPITool_ErrorResponse(t *testing.T) { func TestOpenAPITool_InvalidSpecURL(t *testing.T) { t.Parallel() - _, err := NewOpenAPITool("http://127.0.0.1:1/nonexistent", nil).Tools(t.Context()) + _, err := New("http://127.0.0.1:1/nonexistent", nil).Tools(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "failed to fetch OpenAPI spec") } @@ -405,7 +405,7 @@ func TestOpenAPITool_RejectsLocalSpecURL(t *testing.T) { for _, target := range tests { t.Run(target, func(t *testing.T) { t.Parallel() - _, err := NewOpenAPITool(target, nil).Tools(t.Context()) + _, err := New(target, nil).Tools(t.Context()) require.Error(t, err) assert.Contains(t, err.Error(), "non-public address") }) @@ -445,7 +445,7 @@ func TestOpenAPITool_RejectsLocalSpecServerURL(t *testing.T) { // Even though the spec was fetched in unsafe mode, the generated // handler still inherits the unsafe flag — so for the real safety // guarantee we re-run the operation through the production path. - prod := NewOpenAPITool(specServer.URL+"/openapi.json", nil) + prod := New(specServer.URL+"/openapi.json", nil) prodTools, err := prod.Tools(t.Context()) require.Error(t, err, "production constructor must refuse a loopback spec server") assert.Nil(t, prodTools) @@ -454,7 +454,7 @@ func TestOpenAPITool_RejectsLocalSpecServerURL(t *testing.T) { func TestOpenAPITool_Instructions(t *testing.T) { t.Parallel() - instructions := NewOpenAPITool("https://example.com/openapi.json", nil).Instructions() + instructions := New("https://example.com/openapi.json", nil).Instructions() assert.Contains(t, instructions, "OpenAPI") assert.Contains(t, instructions, "https://example.com/openapi.json") diff --git a/pkg/tools/builtin/rag/rag.go b/pkg/tools/builtin/rag/rag.go index f3a64066d..743773b0a 100644 --- a/pkg/tools/builtin/rag/rag.go +++ b/pkg/tools/builtin/rag/rag.go @@ -9,16 +9,42 @@ import ( "log/slog" "slices" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/rag" ragtypes "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" ) +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + if toolset.RAGConfig == nil { + return nil, errors.New("rag toolset requires either a rag_config block or a ref") + } + + ragName := cmp.Or(toolset.Name, "rag") + + mgr, err := rag.NewManager(ctx, ragName, toolset.RAGConfig, rag.ManagersBuildConfig{ + ParentDir: parentDir, + ModelsGateway: runConfig.ModelsGateway, + Env: runConfig.EnvProvider(), + Models: runConfig.Models, + Providers: runConfig.Providers, + RuntimeConfig: runConfig, + }) + if err != nil { + return nil, fmt.Errorf("failed to create RAG manager: %w", err) + } + + toolName := cmp.Or(mgr.ToolName(), ragName) + return New(mgr, toolName), nil +} + // EventCallback is called to forward RAG manager events during initialization. type EventCallback func(event ragtypes.Event) -// Tool provides document querying capabilities for a single RAG source. -type Tool struct { +// ToolSet provides document querying capabilities for a single RAG source. +type ToolSet struct { manager *rag.Manager toolName string eventCallback EventCallback @@ -26,33 +52,33 @@ type Tool struct { // Verify interface compliance. var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) - _ tools.Startable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) + _ tools.Startable = (*ToolSet)(nil) ) -// NewRAGTool creates a new RAG tool for a single RAG manager. -func NewRAGTool(manager *rag.Manager, toolName string) *Tool { - return &Tool{ +// New creates a new RAG toolset for a single RAG manager. +func New(manager *rag.Manager, toolName string) *ToolSet { + return &ToolSet{ manager: manager, toolName: toolName, } } // Name returns the tool name for this RAG source. -func (t *Tool) Name() string { +func (t *ToolSet) Name() string { return t.toolName } // SetEventCallback sets a callback to receive RAG manager events during // initialization. Must be called before Start(). -func (t *Tool) SetEventCallback(cb EventCallback) { +func (t *ToolSet) SetEventCallback(cb EventCallback) { t.eventCallback = cb } // Start initializes the RAG manager (indexes documents) and starts a // file watcher for incremental updates. -func (t *Tool) Start(ctx context.Context) error { +func (t *ToolSet) Start(ctx context.Context) error { if t.manager == nil { return nil } @@ -75,7 +101,7 @@ func (t *Tool) Start(ctx context.Context) error { } // Stop closes the RAG manager and releases resources. -func (t *Tool) Stop(_ context.Context) error { +func (t *ToolSet) Stop(_ context.Context) error { if t.manager == nil { return nil } @@ -83,7 +109,7 @@ func (t *Tool) Stop(_ context.Context) error { } // forwardEvents reads events from the RAG manager and forwards them via the callback. -func (t *Tool) forwardEvents(ctx context.Context) { +func (t *ToolSet) forwardEvents(ctx context.Context) { for { select { case <-ctx.Done(): @@ -97,7 +123,7 @@ func (t *Tool) forwardEvents(ctx context.Context) { } } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { if t.manager != nil { if instruction := t.manager.ToolInstruction(); instruction != "" { return instruction @@ -118,7 +144,7 @@ type queryResult struct { ChunkIndex int `json:"chunk_index" jsonschema:"Index of the chunk within the source document"` } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { var description string if t.manager != nil { description = t.manager.Description() @@ -141,7 +167,7 @@ func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { }}, nil } -func (t *Tool) handleQueryRAG(ctx context.Context, args queryRAGArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) handleQueryRAG(ctx context.Context, args queryRAGArgs) (*tools.ToolCallResult, error) { if args.Query == "" { return nil, errors.New("query cannot be empty") } diff --git a/pkg/tools/builtin/rag/rag_test.go b/pkg/tools/builtin/rag/rag_test.go index f99af9664..162bbf39f 100644 --- a/pkg/tools/builtin/rag/rag_test.go +++ b/pkg/tools/builtin/rag/rag_test.go @@ -29,7 +29,7 @@ func TestRAGTool_ToolName(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tool := &Tool{ + tool := &ToolSet{ toolName: tt.toolName, manager: nil, } @@ -44,7 +44,7 @@ func TestRAGTool_ToolName(t *testing.T) { } func TestRAGTool_DefaultDescription(t *testing.T) { - tool := &Tool{ + tool := &ToolSet{ toolName: "test_docs", manager: nil, } diff --git a/pkg/tools/builtin/shell/script_shell.go b/pkg/tools/builtin/shell/script_shell.go index bc002ad5a..d2a503053 100644 --- a/pkg/tools/builtin/shell/script_shell.go +++ b/pkg/tools/builtin/shell/script_shell.go @@ -4,6 +4,7 @@ import ( "cmp" "context" "encoding/json" + "errors" "fmt" "maps" "os" @@ -11,30 +12,46 @@ import ( "slices" "strings" + "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/shellpath" "github.com/docker/docker-agent/pkg/tools" ) -type ScriptShellTool struct { +// CreateScriptToolSet is used by the tools registry. +func CreateScriptToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + if len(toolset.Shell) == 0 { + return nil, errors.New("shell is required for script toolset") + } + + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), runConfig.EnvProvider()) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + env = append(env, os.Environ()...) + return NewScript(toolset.Shell, env) +} + +type ScriptToolSet struct { shellTools map[string]latest.ScriptShellToolConfig env []string } // Verify interface compliance var ( - _ tools.ToolSet = (*ScriptShellTool)(nil) - _ tools.Instructable = (*ScriptShellTool)(nil) + _ tools.ToolSet = (*ScriptToolSet)(nil) + _ tools.Instructable = (*ScriptToolSet)(nil) ) -func NewScriptShellTool(shellTools map[string]latest.ScriptShellToolConfig, env []string) (*ScriptShellTool, error) { +func NewScript(shellTools map[string]latest.ScriptShellToolConfig, env []string) (*ScriptToolSet, error) { for toolName, tool := range shellTools { if err := validateConfig(toolName, tool); err != nil { return nil, err } } - return &ScriptShellTool{ + return &ScriptToolSet{ shellTools: shellTools, env: env, }, nil @@ -71,7 +88,7 @@ func validateConfig(toolName string, tool latest.ScriptShellToolConfig) error { return nil } -func (t *ScriptShellTool) Instructions() string { +func (t *ScriptToolSet) Instructions() string { var sb strings.Builder sb.WriteString("## Custom Shell Tools\n\n") @@ -97,7 +114,7 @@ func (t *ScriptShellTool) Instructions() string { return sb.String() } -func (t *ScriptShellTool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ScriptToolSet) Tools(context.Context) ([]tools.Tool, error) { var toolsList []tools.Tool for name, toolConfig := range t.shellTools { @@ -130,7 +147,7 @@ func (t *ScriptShellTool) Tools(context.Context) ([]tools.Tool, error) { return toolsList, nil } -func (t *ScriptShellTool) execute(ctx context.Context, toolConfig *latest.ScriptShellToolConfig, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { +func (t *ScriptToolSet) execute(ctx context.Context, toolConfig *latest.ScriptShellToolConfig, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { var params map[string]any if toolCall.Function.Arguments != "" { if err := json.Unmarshal([]byte(toolCall.Function.Arguments), ¶ms); err != nil { diff --git a/pkg/tools/builtin/shell/script_shell_test.go b/pkg/tools/builtin/shell/script_shell_test.go index bb261d65a..4c6bc4017 100644 --- a/pkg/tools/builtin/shell/script_shell_test.go +++ b/pkg/tools/builtin/shell/script_shell_test.go @@ -12,8 +12,8 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -func TestNewScriptShellTool_Empty(t *testing.T) { - tool, err := NewScriptShellTool(nil, nil) +func TestNewScript_Empty(t *testing.T) { + tool, err := NewScript(nil, nil) require.NoError(t, err) allTools, err := tool.Tools(t.Context()) @@ -21,14 +21,14 @@ func TestNewScriptShellTool_Empty(t *testing.T) { assert.Empty(t, allTools) } -func TestNewScriptShellTool_ToolNoArg(t *testing.T) { +func TestNewScript_ToolNoArg(t *testing.T) { shellTools := map[string]latest.ScriptShellToolConfig{ "get_ip": { Description: "Get public IP", }, } - tool, err := NewScriptShellTool(shellTools, nil) + tool, err := NewScript(shellTools, nil) require.NoError(t, err) allTools, err := tool.Tools(t.Context()) @@ -43,7 +43,7 @@ func TestNewScriptShellTool_ToolNoArg(t *testing.T) { }`, string(schema)) } -func TestNewScriptShellTool_Tool(t *testing.T) { +func TestNewScript_Tool(t *testing.T) { shellTools := map[string]latest.ScriptShellToolConfig{ "github_user_repos": { Description: "List GitHub repositories of the provided user", @@ -57,7 +57,7 @@ func TestNewScriptShellTool_Tool(t *testing.T) { }, } - tool, err := NewScriptShellTool(shellTools, nil) + tool, err := NewScript(shellTools, nil) require.NoError(t, err) allTools, err := tool.Tools(t.Context()) @@ -78,7 +78,7 @@ func TestNewScriptShellTool_Tool(t *testing.T) { }`, string(schema)) } -func TestNewScriptShellTool_Typo(t *testing.T) { +func TestNewScript_Typo(t *testing.T) { shellTools := map[string]latest.ScriptShellToolConfig{ "docker_images": { Description: "List running Docker containers", @@ -93,12 +93,12 @@ func TestNewScriptShellTool_Typo(t *testing.T) { }, } - tool, err := NewScriptShellTool(shellTools, nil) + tool, err := NewScript(shellTools, nil) require.Nil(t, tool) require.ErrorContains(t, err, "tool 'docker_images' uses undefined args: [image]") } -func TestNewScriptShellTool_MissingRequired(t *testing.T) { +func TestNewScript_MissingRequired(t *testing.T) { shellTools := map[string]latest.ScriptShellToolConfig{ "docker_images": { Description: "List running Docker containers", @@ -113,12 +113,12 @@ func TestNewScriptShellTool_MissingRequired(t *testing.T) { }, } - tool, err := NewScriptShellTool(shellTools, nil) + tool, err := NewScript(shellTools, nil) require.Nil(t, tool) require.ErrorContains(t, err, "tool 'docker_images' has required arg 'img' which is not defined in args") } -func TestNewScriptShellTool_NumberArg(t *testing.T) { +func TestNewScript_NumberArg(t *testing.T) { shellTools := map[string]latest.ScriptShellToolConfig{ "repeat": { Description: "Repeat a message N times", @@ -137,7 +137,7 @@ func TestNewScriptShellTool_NumberArg(t *testing.T) { }, } - tool, err := NewScriptShellTool(shellTools, os.Environ()) + tool, err := NewScript(shellTools, os.Environ()) require.NoError(t, err) allTools, err := tool.Tools(t.Context()) @@ -172,7 +172,7 @@ func TestScriptShellTool_DropsUndeclaredArgs(t *testing.T) { }, } - tool, err := NewScriptShellTool(shellTools, []string{}) + tool, err := NewScript(shellTools, []string{}) require.NoError(t, err) allTools, err := tool.Tools(t.Context()) @@ -206,7 +206,7 @@ func TestScriptShellTool_RejectsNULInValue(t *testing.T) { }, } - tool, err := NewScriptShellTool(shellTools, []string{}) + tool, err := NewScript(shellTools, []string{}) require.NoError(t, err) allTools, err := tool.Tools(t.Context()) @@ -223,7 +223,7 @@ func TestScriptShellTool_RejectsNULInValue(t *testing.T) { assert.Contains(t, result.Output, "NUL byte") } -func TestNewScriptShellTool_ArgWithoutType(t *testing.T) { +func TestNewScript_ArgWithoutType(t *testing.T) { shellTools := map[string]latest.ScriptShellToolConfig{ "greet": { Description: "Greet someone", @@ -237,7 +237,7 @@ func TestNewScriptShellTool_ArgWithoutType(t *testing.T) { }, } - tool, err := NewScriptShellTool(shellTools, nil) + tool, err := NewScript(shellTools, nil) require.NoError(t, err) allTools, err := tool.Tools(t.Context()) diff --git a/pkg/tools/builtin/shell/shell.go b/pkg/tools/builtin/shell/shell.go index 116324438..cf88177d6 100644 --- a/pkg/tools/builtin/shell/shell.go +++ b/pkg/tools/builtin/shell/shell.go @@ -18,6 +18,8 @@ import ( "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/shellpath" "github.com/docker/docker-agent/pkg/tools" ) @@ -30,16 +32,16 @@ const ( ToolNameStopBackgroundJob = "stop_background_job" ) -// Tool provides shell command execution capabilities. -type Tool struct { +// ToolSet provides shell command execution capabilities. +type ToolSet struct { handler *shellHandler } // Verify interface compliance var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Startable = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Startable = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) type shellHandler struct { @@ -450,8 +452,23 @@ func reapSpawnedChild(cmd *exec.Cmd, pg *processGroup) { } } -// NewShellTool creates a new shell tool. -func NewShellTool(env []string, runConfig *config.RuntimeConfig) *Tool { +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), runConfig.EnvProvider()) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + // Re-append os.Environ() after expansion so spawned processes inherit the + // host environment. EnvProvider is used only to expand ${...} references + // in toolset.Env; the subprocess still needs access to the full environment. + + env = append(env, os.Environ()...) + + return New(env, runConfig), nil +} + +// New creates a new shell toolset. +func New(env []string, runConfig *config.RuntimeConfig) *ToolSet { shell, argsPrefix := detectShell() handler := &shellHandler{ @@ -463,7 +480,7 @@ func NewShellTool(env []string, runConfig *config.RuntimeConfig) *Tool { workingDir: runConfig.WorkingDir, } - return &Tool{handler: handler} + return &ToolSet{handler: handler} } // detectShell returns the appropriate shell and arguments based on the platform. @@ -502,7 +519,7 @@ func formatCommandOutput(timeoutCtx, ctx context.Context, err error, rawOutput s return cmp.Or(strings.TrimSpace(output), "") } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return `## Shell Tools - Each call runs in a fresh shell session — no state persists between calls @@ -516,7 +533,7 @@ func (t *Tool) Instructions() string { Use run_background_job for long-running processes (servers, watchers). Output capped at 10MB per job. All jobs auto-terminate when the agent stops.` } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameShell, @@ -570,11 +587,11 @@ func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { }, nil } -func (t *Tool) Start(context.Context) error { +func (t *ToolSet) Start(context.Context) error { return nil } -func (t *Tool) Stop(context.Context) error { +func (t *ToolSet) Stop(context.Context) error { // Terminate all running background jobs t.handler.jobs.Range(func(_ string, job *backgroundJob) bool { if job.status.CompareAndSwap(statusRunning, statusStopped) { diff --git a/pkg/tools/builtin/shell/shell_test.go b/pkg/tools/builtin/shell/shell_test.go index 2b98c5dc2..b9b9f0495 100644 --- a/pkg/tools/builtin/shell/shell_test.go +++ b/pkg/tools/builtin/shell/shell_test.go @@ -15,16 +15,16 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -func TestNewShellTool(t *testing.T) { +func TestNew(t *testing.T) { t.Setenv("SHELL", "/bin/bash") - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) assert.NotNil(t, tool) assert.NotNil(t, tool.handler) assert.Equal(t, "/bin/bash", tool.handler.shell) t.Setenv("SHELL", "") - tool = NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool = New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) assert.NotNil(t, tool) assert.NotNil(t, tool.handler) @@ -32,7 +32,7 @@ func TestNewShellTool(t *testing.T) { } func TestShellTool_HandlerEcho(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) result, err := tool.handler.RunShell(t.Context(), RunShellArgs{ Cmd: "echo 'hello world'", @@ -43,7 +43,7 @@ func TestShellTool_HandlerEcho(t *testing.T) { } func TestShellTool_HandlerWithCwd(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) tmpDir := t.TempDir() result, err := tool.handler.RunShell(t.Context(), RunShellArgs{ @@ -136,7 +136,7 @@ func TestRunShellBackgroundArgs_UnmarshalJSON_AcceptsCmdAndCommand(t *testing.T) // use "command" instead of "cmd" must execute normally rather than return // the missing-parameter error. func TestShellTool_HandlerAcceptsCommandAlias(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) var params RunShellArgs require.NoError(t, json.Unmarshal([]byte(`{"command":"echo hello-from-alias"}`), ¶ms)) @@ -147,7 +147,7 @@ func TestShellTool_HandlerAcceptsCommandAlias(t *testing.T) { } func TestShellTool_HandlerMissingCmdReturnsActionableError(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) result, err := tool.handler.RunShell(t.Context(), RunShellArgs{}) require.NoError(t, err) @@ -156,7 +156,7 @@ func TestShellTool_HandlerMissingCmdReturnsActionableError(t *testing.T) { } func TestShellTool_HandlerError(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) result, err := tool.handler.RunShell(t.Context(), RunShellArgs{ Cmd: "command_that_does_not_exist", @@ -167,7 +167,7 @@ func TestShellTool_HandlerError(t *testing.T) { } func TestShellTool_OutputSchema(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -179,7 +179,7 @@ func TestShellTool_OutputSchema(t *testing.T) { } func TestShellTool_ParametersAreObjects(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -194,7 +194,7 @@ func TestShellTool_ParametersAreObjects(t *testing.T) { // Minimal tests for background job features func TestShellTool_RunBackgroundJob(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) err := tool.Start(t.Context()) require.NoError(t, err) t.Cleanup(func() { @@ -207,7 +207,7 @@ func TestShellTool_RunBackgroundJob(t *testing.T) { } func TestShellTool_ListBackgroundJobs(t *testing.T) { - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) err := tool.Start(t.Context()) require.NoError(t, err) t.Cleanup(func() { @@ -229,7 +229,7 @@ func TestShellTool_ListBackgroundJobs(t *testing.T) { func TestShellTool_Instructions(t *testing.T) { t.Parallel() - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) instructions := tool.Instructions() @@ -270,7 +270,7 @@ func TestShellTool_RelativeCwdResolvesAgainstWorkingDir(t *testing.T) { subdir := workingDir + "/subdir" require.NoError(t, os.Mkdir(subdir, 0o755)) - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: workingDir}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: workingDir}}) result, err := tool.handler.RunShell(t.Context(), RunShellArgs{ Cmd: "pwd", @@ -296,7 +296,7 @@ func TestShellTool_BackgroundedChildDoesNotBlockReturn(t *testing.T) { t.Skip("POSIX shell backgrounding semantics; skipped on Windows") } - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) start := time.Now() result, err := tool.handler.RunShell(t.Context(), RunShellArgs{ @@ -325,7 +325,7 @@ func TestShellTool_DetachedBackgroundedChildDoesNotBlockReturn(t *testing.T) { t.Skip("setsid not available") } - tool := NewShellTool(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) + tool := New(nil, &config.RuntimeConfig{Config: config.Config{WorkingDir: t.TempDir()}}) done := make(chan struct{}) var result *tools.ToolCallResult diff --git a/pkg/tools/builtin/skills/skills.go b/pkg/tools/builtin/skills/skills.go index 970e590c7..e59936283 100644 --- a/pkg/tools/builtin/skills/skills.go +++ b/pkg/tools/builtin/skills/skills.go @@ -18,31 +18,31 @@ const ( ) var ( - _ tools.ToolSet = (*Toolset)(nil) - _ tools.Instructable = (*Toolset)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) -// Toolset provides the read_skill and read_skill_file tools that let an +// ToolSet provides the read_skill and read_skill_file tools that let an // agent load skill content and supporting resources by name. It hides whether // a skill is local or remote — the agent just sees a name and description. -type Toolset struct { +type ToolSet struct { skills []skills.Skill workingDir string } -func NewSkillsToolset(loadedSkills []skills.Skill, workingDir string) *Toolset { - return &Toolset{ +func New(loadedSkills []skills.Skill, workingDir string) *ToolSet { + return &ToolSet{ skills: loadedSkills, workingDir: workingDir, } } // Skills returns the loaded skills (used by the app layer for slash commands). -func (s *Toolset) Skills() []skills.Skill { +func (s *ToolSet) Skills() []skills.Skill { return s.skills } -func (s *Toolset) findSkill(name string) *skills.Skill { +func (s *ToolSet) findSkill(name string) *skills.Skill { for i := range s.skills { if s.skills[i].Name == name { return &s.skills[i] @@ -52,7 +52,7 @@ func (s *Toolset) findSkill(name string) *skills.Skill { } // FindSkill returns the skill with the given name, or nil if not found. -func (s *Toolset) FindSkill(name string) *skills.Skill { +func (s *ToolSet) FindSkill(name string) *skills.Skill { return s.findSkill(name) } @@ -60,7 +60,7 @@ func (s *Toolset) FindSkill(name string) *skills.Skill { // For local skills, it expands any !`command` patterns in the content by // executing the commands and replacing the patterns with their stdout output. // Command expansion is disabled for remote skills to prevent arbitrary code execution. -func (s *Toolset) ReadSkillContent(ctx context.Context, name string) (string, error) { +func (s *ToolSet) ReadSkillContent(ctx context.Context, name string) (string, error) { skill := s.findSkill(name) if skill == nil { return "", fmt.Errorf("skill %q not found", name) @@ -80,7 +80,7 @@ func (s *Toolset) ReadSkillContent(ctx context.Context, name string) (string, er // ReadSkillFile returns the content of a supporting file within a skill. // The path is relative to the skill's base directory (e.g. "references/FORMS.md"). -func (s *Toolset) ReadSkillFile(skillName, relativePath string) (string, error) { +func (s *ToolSet) ReadSkillFile(skillName, relativePath string) (string, error) { skill := s.findSkill(skillName) if skill == nil { return "", fmt.Errorf("skill %q not found", skillName) @@ -134,7 +134,7 @@ type readSkillFileArgs struct { Path string `json:"path" jsonschema:"The relative path to the file within the skill (e.g. references/FORMS.md)"` } -func (s *Toolset) handleReadSkill(ctx context.Context, args readSkillArgs) (*tools.ToolCallResult, error) { +func (s *ToolSet) handleReadSkill(ctx context.Context, args readSkillArgs) (*tools.ToolCallResult, error) { content, err := s.ReadSkillContent(ctx, args.Name) if err != nil { return tools.ResultError(err.Error()), nil @@ -142,7 +142,7 @@ func (s *Toolset) handleReadSkill(ctx context.Context, args readSkillArgs) (*too return tools.ResultSuccess(content), nil } -func (s *Toolset) handleReadSkillFile(_ context.Context, args readSkillFileArgs) (*tools.ToolCallResult, error) { +func (s *ToolSet) handleReadSkillFile(_ context.Context, args readSkillFileArgs) (*tools.ToolCallResult, error) { content, err := s.ReadSkillFile(args.SkillName, args.Path) if err != nil { return tools.ResultError(err.Error()), nil @@ -151,7 +151,7 @@ func (s *Toolset) handleReadSkillFile(_ context.Context, args readSkillFileArgs) } // hasFiles reports whether any loaded skill has supporting files beyond SKILL.md. -func (s *Toolset) hasFiles() bool { +func (s *ToolSet) hasFiles() bool { for _, skill := range s.skills { if len(skill.Files) > 1 { return true @@ -161,7 +161,7 @@ func (s *Toolset) hasFiles() bool { } // hasForkSkills reports whether any loaded skill uses context: fork. -func (s *Toolset) hasForkSkills() bool { +func (s *ToolSet) hasForkSkills() bool { for i := range s.skills { if s.skills[i].IsFork() { return true @@ -170,7 +170,7 @@ func (s *Toolset) hasForkSkills() bool { return false } -func (s *Toolset) Instructions() string { +func (s *ToolSet) Instructions() string { if len(s.skills) == 0 { return "" } @@ -258,7 +258,7 @@ type PreparedSkillFork struct { // skill not configured for fork mode, content read failure). The caller is // responsible for the runtime-specific orchestration (sub-session creation, // tracing, event forwarding). -func (s *Toolset) PrepareForkSubSession(ctx context.Context, args RunSkillArgs) (*PreparedSkillFork, *tools.ToolCallResult) { +func (s *ToolSet) PrepareForkSubSession(ctx context.Context, args RunSkillArgs) (*PreparedSkillFork, *tools.ToolCallResult) { skill := s.findSkill(args.Name) if skill == nil { return nil, tools.ResultError(fmt.Sprintf("skill %q not found", args.Name)) @@ -284,7 +284,7 @@ func (s *Toolset) PrepareForkSubSession(ctx context.Context, args RunSkillArgs) }, nil } -func (s *Toolset) Tools(context.Context) ([]tools.Tool, error) { +func (s *ToolSet) Tools(context.Context) ([]tools.Tool, error) { if len(s.skills) == 0 { return nil, nil } diff --git a/pkg/tools/builtin/skills/skills_test.go b/pkg/tools/builtin/skills/skills_test.go index c9ba10a68..e95b28a2a 100644 --- a/pkg/tools/builtin/skills/skills_test.go +++ b/pkg/tools/builtin/skills/skills_test.go @@ -17,7 +17,7 @@ func TestSkillsToolset_ReadSkillContent_Local(t *testing.T) { skillFile := filepath.Join(tmpDir, "SKILL.md") require.NoError(t, os.WriteFile(skillFile, []byte("# Local Skill\nDo the thing."), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "local-skill", Description: "A local skill", FilePath: skillFile, BaseDir: tmpDir}, }, "") @@ -27,7 +27,7 @@ func TestSkillsToolset_ReadSkillContent_Local(t *testing.T) { } func TestSkillsToolset_ReadSkillContent_NotFound(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "exists", Description: "Exists", FilePath: "/tmp/nonexistent"}, }, "") @@ -42,7 +42,7 @@ func TestSkillsToolset_ReadSkillFile(t *testing.T) { require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "references"), 0o755)) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "references", "FORMS.md"), []byte("# Forms Reference"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ { Name: "my-skill", Description: "My skill", FilePath: filepath.Join(tmpDir, "SKILL.md"), BaseDir: tmpDir, Files: []string{"SKILL.md", "references/FORMS.md"}, @@ -58,7 +58,7 @@ func TestSkillsToolset_ReadSkillFile_PathTraversal(t *testing.T) { tmpDir := t.TempDir() require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "SKILL.md"), []byte("# Main"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "my-skill", Description: "My skill", FilePath: filepath.Join(tmpDir, "SKILL.md"), BaseDir: tmpDir}, }, "") @@ -72,7 +72,7 @@ func TestSkillsToolset_ReadSkillFile_PathTraversal(t *testing.T) { } func TestSkillsToolset_ReadSkillFile_SkillNotFound(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "exists", Description: "Exists", FilePath: "/tmp/test"}, }, "") @@ -82,7 +82,7 @@ func TestSkillsToolset_ReadSkillFile_SkillNotFound(t *testing.T) { } func TestSkillsToolset_Instructions(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "skill-a", Description: "Does A"}, {Name: "skill-b", Description: "Does B", Files: []string{"SKILL.md", "references/HELP.md"}}, }, "") @@ -102,7 +102,7 @@ func TestSkillsToolset_Instructions(t *testing.T) { } func TestSkillsToolset_Instructions_NoFiles(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "simple", Description: "Simple skill"}, }, "") @@ -114,15 +114,15 @@ func TestSkillsToolset_Instructions_NoFiles(t *testing.T) { } func TestSkillsToolset_Instructions_Empty(t *testing.T) { - st := NewSkillsToolset(nil, "") + st := New(nil, "") assert.Empty(t, st.Instructions()) - st = NewSkillsToolset([]skills.Skill{}, "") + st = New([]skills.Skill{}, "") assert.Empty(t, st.Instructions()) } func TestSkillsToolset_Tools_WithFiles(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "test", Description: "Test skill", Files: []string{"SKILL.md", "references/HELP.md"}}, }, "") @@ -135,7 +135,7 @@ func TestSkillsToolset_Tools_WithFiles(t *testing.T) { } func TestSkillsToolset_Tools_WithoutFiles(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "test", Description: "Test skill"}, }, "") @@ -147,7 +147,7 @@ func TestSkillsToolset_Tools_WithoutFiles(t *testing.T) { } func TestSkillsToolset_Tools_Empty(t *testing.T) { - st := NewSkillsToolset(nil, "") + st := New(nil, "") tools, err := st.Tools(t.Context()) require.NoError(t, err) @@ -159,7 +159,7 @@ func TestSkillsToolset_Skills(t *testing.T) { {Name: "a", Description: "A"}, {Name: "b", Description: "B"}, } - st := NewSkillsToolset(input, "") + st := New(input, "") assert.Equal(t, input, st.Skills()) } @@ -169,7 +169,7 @@ func TestSkillsToolset_HandleReadSkill(t *testing.T) { skillFile := filepath.Join(tmpDir, "SKILL.md") require.NoError(t, os.WriteFile(skillFile, []byte("skill instructions"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "test-skill", Description: "Test", FilePath: skillFile, BaseDir: tmpDir}, }, "") @@ -180,7 +180,7 @@ func TestSkillsToolset_HandleReadSkill(t *testing.T) { } func TestSkillsToolset_HandleReadSkill_NotFound(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "exists", Description: "Exists", FilePath: "/tmp/test"}, }, "") @@ -196,7 +196,7 @@ func TestSkillsToolset_HandleReadSkillFile(t *testing.T) { require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "scripts"), 0o755)) require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "scripts", "deploy.sh"), []byte("#!/bin/bash\necho deploy"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ { Name: "my-skill", Description: "My skill", FilePath: filepath.Join(tmpDir, "SKILL.md"), BaseDir: tmpDir, Files: []string{"SKILL.md", "scripts/deploy.sh"}, @@ -213,7 +213,7 @@ func TestSkillsToolset_HandleReadSkillFile_PathTraversal(t *testing.T) { tmpDir := t.TempDir() require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "SKILL.md"), []byte("# Main"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "my-skill", Description: "My skill", FilePath: filepath.Join(tmpDir, "SKILL.md"), BaseDir: tmpDir}, }, "") @@ -233,7 +233,7 @@ func TestSkillsToolset_ReadSkillContent_ExpandsCommands(t *testing.T) { content := "# Skill\nBranch: !`echo main`\nDone." require.NoError(t, os.WriteFile(skillFile, []byte(content), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "expand-skill", Description: "Expands commands", FilePath: skillFile, BaseDir: tmpDir, Local: true}, }, tmpDir) @@ -257,7 +257,7 @@ func TestSkillsToolset_ReadSkillContent_ExpandsScript(t *testing.T) { content := "Data: !`./gather.sh`" require.NoError(t, os.WriteFile(skillFile, []byte(content), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "script-skill", Description: "Runs scripts", FilePath: skillFile, BaseDir: tmpDir, Local: true}, }, tmpDir) @@ -272,7 +272,7 @@ func TestSkillsToolset_ReadSkillContent_RemoteSkillSkipsExpansion(t *testing.T) content := "Info: !`echo should-not-run`" require.NoError(t, os.WriteFile(skillFile, []byte(content), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "remote-skill", Description: "Remote", FilePath: skillFile, BaseDir: tmpDir, Local: false}, }, "") @@ -282,7 +282,7 @@ func TestSkillsToolset_ReadSkillContent_RemoteSkillSkipsExpansion(t *testing.T) } func TestSkillsToolset_FindSkill(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "alpha", Description: "Alpha skill"}, {Name: "beta", Description: "Beta skill"}, }, "") @@ -299,7 +299,7 @@ func TestSkillsToolset_FindSkill(t *testing.T) { } func TestSkillsToolset_Instructions_ForkSkills(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "inline-skill", Description: "Runs inline"}, {Name: "fork-skill", Description: "Runs as sub-agent", Context: "fork"}, }, "") @@ -320,7 +320,7 @@ func TestSkillsToolset_Instructions_ForkSkills(t *testing.T) { } func TestSkillsToolset_Instructions_NoForkSkills(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "normal", Description: "Normal skill"}, }, "") @@ -333,7 +333,7 @@ func TestSkillsToolset_Instructions_NoForkSkills(t *testing.T) { } func TestSkillsToolset_Tools_WithForkSkills(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "inline", Description: "Inline skill"}, {Name: "forked", Description: "Forked skill", Context: "fork"}, }, "") @@ -348,7 +348,7 @@ func TestSkillsToolset_Tools_WithForkSkills(t *testing.T) { } func TestSkillsToolset_Tools_NoForkSkills(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "inline", Description: "Inline skill"}, }, "") @@ -365,7 +365,7 @@ func TestSkillsToolset_PrepareForkSubSession(t *testing.T) { skillFile := filepath.Join(tmpDir, "SKILL.md") require.NoError(t, os.WriteFile(skillFile, []byte("system instructions"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "forked", Description: "Forked", Context: "fork", FilePath: skillFile, BaseDir: tmpDir, Model: "openai/gpt-4o-mini"}, }, "") @@ -383,7 +383,7 @@ func TestSkillsToolset_PrepareForkSubSession_NoModelOverride(t *testing.T) { skillFile := filepath.Join(tmpDir, "SKILL.md") require.NoError(t, os.WriteFile(skillFile, []byte("system instructions"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ // No Model set in the frontmatter — Prepared.Model must be empty. {Name: "forked", Description: "Forked", Context: "fork", FilePath: skillFile, BaseDir: tmpDir}, }, "") @@ -395,7 +395,7 @@ func TestSkillsToolset_PrepareForkSubSession_NoModelOverride(t *testing.T) { } func TestSkillsToolset_PrepareForkSubSession_NotFound(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "exists", Description: "Exists", Context: "fork", FilePath: "/tmp/nonexistent"}, }, "") @@ -411,7 +411,7 @@ func TestSkillsToolset_PrepareForkSubSession_NotFork(t *testing.T) { skillFile := filepath.Join(tmpDir, "SKILL.md") require.NoError(t, os.WriteFile(skillFile, []byte("inline"), 0o644)) - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ // No Context: "fork" — this is an inline skill. {Name: "inline-only", Description: "Inline", FilePath: skillFile, BaseDir: tmpDir}, }, "") @@ -425,7 +425,7 @@ func TestSkillsToolset_PrepareForkSubSession_NotFork(t *testing.T) { } func TestSkillsToolset_PrepareForkSubSession_ReadFailure(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ // FilePath does not exist on disk; ReadSkillContent will fail. {Name: "forked", Description: "Forked", Context: "fork", FilePath: "/does/not/exist/SKILL.md"}, }, "") @@ -438,7 +438,7 @@ func TestSkillsToolset_PrepareForkSubSession_ReadFailure(t *testing.T) { } func TestSkillsToolset_Tools_ForkAndFiles(t *testing.T) { - st := NewSkillsToolset([]skills.Skill{ + st := New([]skills.Skill{ {Name: "full", Description: "Full skill", Context: "fork", Files: []string{"SKILL.md", "ref.md"}}, }, "") diff --git a/pkg/tools/builtin/tasks/tasks.go b/pkg/tools/builtin/tasks/tasks.go index d8e3d281d..311a342b7 100644 --- a/pkg/tools/builtin/tasks/tasks.go +++ b/pkg/tools/builtin/tasks/tasks.go @@ -14,8 +14,11 @@ import ( "github.com/google/uuid" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/path" "github.com/docker/docker-agent/pkg/tools" + "github.com/docker/docker-agent/pkg/tools/toolsetpath" ) const ( @@ -88,25 +91,43 @@ type taskStore struct { Tasks map[string]Task `json:"tasks"` } -type Tool struct { +type ToolSet struct { mu sync.Mutex filePath string basePath string } var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) + _ tools.ToolSet = (*ToolSet)(nil) + _ tools.Instructable = (*ToolSet)(nil) ) -func NewTasksTool(storagePath string) *Tool { - return &Tool{ +// CreateToolSet is used by the tools registry. +func CreateToolSet(toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + toolsetPath := toolset.Path + if toolsetPath == "" { + toolsetPath = "tasks.json" + } + + validatedPath, err := toolsetpath.Resolve(toolsetPath, parentDir, runConfig) + if err != nil { + return nil, fmt.Errorf("invalid tasks storage path: %w", err) + } + if err := os.MkdirAll(filepath.Dir(validatedPath), 0o700); err != nil { + return nil, fmt.Errorf("failed to create tasks storage directory: %w", err) + } + + return New(validatedPath), nil +} + +func New(storagePath string) *ToolSet { + return &ToolSet{ filePath: storagePath, basePath: filepath.Dir(storagePath), } } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return `## Task Tools Persistent task management with priorities (critical > high > medium > low), statuses (pending, in_progress, done, blocked), and dependencies. Tasks persist across sessions. @@ -114,7 +135,7 @@ Persistent task management with priorities (critical > high > medium > low), sta A task is automatically blocked if any dependency is not done. Use next_task to get the highest-priority actionable task.` } -func (t *Tool) load() taskStore { +func (t *ToolSet) load() taskStore { data, err := os.ReadFile(t.filePath) if err != nil { return taskStore{Tasks: make(map[string]Task)} @@ -129,7 +150,7 @@ func (t *Tool) load() taskStore { return store } -func (t *Tool) save(store taskStore) error { +func (t *ToolSet) save(store taskStore) error { if err := os.MkdirAll(filepath.Dir(t.filePath), 0o700); err != nil { return fmt.Errorf("creating storage directory: %w", err) } @@ -178,7 +199,7 @@ func now() string { return time.Now().UTC().Format(time.RFC3339) } -func (t *Tool) resolveDescription(description, filePath string) (string, error) { +func (t *ToolSet) resolveDescription(description, filePath string) (string, error) { if filePath != "" { validatedPath, err := path.ValidatePathInDirectory(filePath, t.basePath) if err != nil { @@ -254,7 +275,7 @@ type RemoveDependencyArgs struct { // Tool handlers -func (t *Tool) createTask(_ context.Context, params CreateTaskArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) createTask(_ context.Context, params CreateTaskArgs) (*tools.ToolCallResult, error) { desc, err := t.resolveDescription(params.Description, params.Path) if err != nil { return tools.ResultError(err.Error()), nil @@ -305,7 +326,7 @@ func (t *Tool) createTask(_ context.Context, params CreateTaskArgs) (*tools.Tool return taskResult(task), nil } -func (t *Tool) getTask(_ context.Context, params GetTaskArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) getTask(_ context.Context, params GetTaskArgs) (*tools.ToolCallResult, error) { t.mu.Lock() defer t.mu.Unlock() @@ -318,7 +339,7 @@ func (t *Tool) getTask(_ context.Context, params GetTaskArgs) (*tools.ToolCallRe return taskWithEffectiveResult(task, store.Tasks), nil } -func (t *Tool) updateTask(_ context.Context, params UpdateTaskArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) updateTask(_ context.Context, params UpdateTaskArgs) (*tools.ToolCallResult, error) { t.mu.Lock() defer t.mu.Unlock() @@ -372,7 +393,7 @@ func (t *Tool) updateTask(_ context.Context, params UpdateTaskArgs) (*tools.Tool return taskResult(task), nil } -func (t *Tool) deleteTask(_ context.Context, params DeleteTaskArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) deleteTask(_ context.Context, params DeleteTaskArgs) (*tools.ToolCallResult, error) { t.mu.Lock() defer t.mu.Unlock() @@ -401,7 +422,7 @@ func (t *Tool) deleteTask(_ context.Context, params DeleteTaskArgs) (*tools.Tool return tools.ResultJSON(map[string]string{"deleted": params.ID}), nil } -func (t *Tool) listTasks(_ context.Context, params ListTasksArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) listTasks(_ context.Context, params ListTasksArgs) (*tools.ToolCallResult, error) { t.mu.Lock() defer t.mu.Unlock() @@ -438,7 +459,7 @@ func (t *Tool) listTasks(_ context.Context, params ListTasksArgs) (*tools.ToolCa return tools.ResultJSON(tasks), nil } -func (t *Tool) nextTask(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) { +func (t *ToolSet) nextTask(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) { t.mu.Lock() defer t.mu.Unlock() @@ -461,7 +482,7 @@ func (t *Tool) nextTask(_ context.Context, _ tools.ToolCall) (*tools.ToolCallRes return tools.ResultSuccess("No actionable tasks. Everything is either done or blocked."), nil } -func (t *Tool) addDependency(_ context.Context, params AddDependencyArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) addDependency(_ context.Context, params AddDependencyArgs) (*tools.ToolCallResult, error) { t.mu.Lock() defer t.mu.Unlock() @@ -493,7 +514,7 @@ func (t *Tool) addDependency(_ context.Context, params AddDependencyArgs) (*tool return taskResult(task), nil } -func (t *Tool) removeDependency(_ context.Context, params RemoveDependencyArgs) (*tools.ToolCallResult, error) { +func (t *ToolSet) removeDependency(_ context.Context, params RemoveDependencyArgs) (*tools.ToolCallResult, error) { t.mu.Lock() defer t.mu.Unlock() @@ -531,7 +552,7 @@ func taskWithEffectiveResult(task Task, tasks map[string]Task) *tools.ToolCallRe }) } -func (t *Tool) Tools(_ context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(_ context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameCreateTask, diff --git a/pkg/tools/builtin/tasks/tasks_test.go b/pkg/tools/builtin/tasks/tasks_test.go index f58c1e684..354a2da04 100644 --- a/pkg/tools/builtin/tasks/tasks_test.go +++ b/pkg/tools/builtin/tasks/tasks_test.go @@ -12,10 +12,10 @@ import ( "github.com/docker/docker-agent/pkg/tools" ) -func newTestTasksTool(t *testing.T) *Tool { +func newTestTasksTool(t *testing.T) *ToolSet { t.Helper() dir := t.TempDir() - return NewTasksTool(filepath.Join(dir, "tasks.json")) + return New(filepath.Join(dir, "tasks.json")) } func TestTasksTool_DisplayNames(t *testing.T) { @@ -501,13 +501,13 @@ func TestTasksTool_Persistence(t *testing.T) { dir := t.TempDir() storagePath := filepath.Join(dir, "tasks.json") - tool1 := NewTasksTool(storagePath) + tool1 := New(storagePath) r, err := tool1.createTask(t.Context(), CreateTaskArgs{Title: "Persistent"}) require.NoError(t, err) var task Task require.NoError(t, json.Unmarshal([]byte(r.Output), &task)) - tool2 := NewTasksTool(storagePath) + tool2 := New(storagePath) result, err := tool2.getTask(t.Context(), GetTaskArgs{ID: task.ID}) require.NoError(t, err) assert.False(t, result.IsError) diff --git a/pkg/tools/builtin/think/think.go b/pkg/tools/builtin/think/think.go index 2c10c85c3..c6d6e2633 100644 --- a/pkg/tools/builtin/think/think.go +++ b/pkg/tools/builtin/think/think.go @@ -9,30 +9,29 @@ import ( const ToolNameThink = "think" -type Tool struct { - thoughts []string +// CreateToolSet is used by the tools registry. +func CreateToolSet() (tools.ToolSet, error) { + return New(), nil } -// Verify interface compliance -var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) -) +type ToolSet struct { + thoughts []string +} type Args struct { Thought string `json:"thought" jsonschema:"The thought to think about"` } -func (t *Tool) callTool(_ context.Context, params Args) (*tools.ToolCallResult, error) { +func (t *ToolSet) callTool(_ context.Context, params Args) (*tools.ToolCallResult, error) { t.thoughts = append(t.thoughts, params.Thought) return tools.ResultSuccess("Thoughts:\n" + strings.Join(t.thoughts, "\n")), nil } -func NewThinkTool() *Tool { - return &Tool{} +func New() *ToolSet { + return &ToolSet{} } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return `## Think Tool Use the think tool as a scratchpad before acting. Think to: @@ -42,7 +41,7 @@ Use the think tool as a scratchpad before acting. Think to: - Reason through complex multi-step problems` } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameThink, diff --git a/pkg/tools/builtin/think/think_test.go b/pkg/tools/builtin/think/think_test.go index ac57587bf..87a4efca4 100644 --- a/pkg/tools/builtin/think/think_test.go +++ b/pkg/tools/builtin/think/think_test.go @@ -10,7 +10,7 @@ import ( ) func TestThinkTool_Handler(t *testing.T) { - tool := NewThinkTool() + tool := New() result, err := tool.callTool(t.Context(), Args{Thought: "This is a test thought"}) require.NoError(t, err) @@ -24,7 +24,7 @@ func TestThinkTool_Handler(t *testing.T) { } func TestThinkTool_OutputSchema(t *testing.T) { - tool := NewThinkTool() + tool := New() allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -36,7 +36,7 @@ func TestThinkTool_OutputSchema(t *testing.T) { } func TestThinkTool_ParametersAreObjects(t *testing.T) { - tool := NewThinkTool() + tool := New() allTools, err := tool.Tools(t.Context()) require.NoError(t, err) diff --git a/pkg/tools/builtin/todo/todo.go b/pkg/tools/builtin/todo/todo.go index 58dd73f98..ec9a87226 100644 --- a/pkg/tools/builtin/todo/todo.go +++ b/pkg/tools/builtin/todo/todo.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "github.com/docker/docker-agent/pkg/concurrent" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/tools" ) @@ -19,15 +20,18 @@ const ( ToolNameListTodos = "list_todos" ) -type Tool struct { - handler *todoHandler +// CreateToolSet is used by the tools registry. +func CreateToolSet(toolset latest.Toolset) (tools.ToolSet, error) { + if toolset.Shared { + return newSharedTodoTool(), nil + } + + return New(), nil } -// Verify interface compliance -var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) -) +type ToolSet struct { + handler *todoHandler +} type Todo struct { ID string `json:"id" jsonschema:"ID of the todo item"` @@ -131,7 +135,7 @@ func (s *MemoryTodoStorage) Clear(_ context.Context) { } // Option is a functional option for configuring a Tool. -type Option func(*Tool) +type Option func(*ToolSet) // WithStorage sets a custom storage implementation for the Tool. // The provided storage must not be nil. @@ -139,7 +143,7 @@ func WithStorage(storage Storage) Option { if storage == nil { panic("todo: storage must not be nil") } - return func(t *Tool) { + return func(t *ToolSet) { t.handler.storage = storage } } @@ -149,10 +153,10 @@ type todoHandler struct { nextID atomic.Int64 } -var NewSharedTodoTool = sync.OnceValue(func() *Tool { return NewTodoTool() }) +var newSharedTodoTool = sync.OnceValue(func() *ToolSet { return New() }) -func NewTodoTool(opts ...Option) *Tool { - t := &Tool{ +func New(opts ...Option) *ToolSet { + t := &ToolSet{ handler: &todoHandler{ storage: NewMemoryTodoStorage(), }, @@ -163,7 +167,7 @@ func NewTodoTool(opts ...Option) *Tool { return t } -func (t *Tool) Instructions() string { +func (t *ToolSet) Instructions() string { return `## Todo Tools Track task progress with todos: @@ -288,7 +292,7 @@ func (h *todoHandler) listTodos(ctx context.Context, _ tools.ToolCall) (*tools.T return h.jsonResult(ctx, out) } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameCreateTodo, diff --git a/pkg/tools/builtin/todo/todo_test.go b/pkg/tools/builtin/todo/todo_test.go index bd91667a2..b6fbb630e 100644 --- a/pkg/tools/builtin/todo/todo_test.go +++ b/pkg/tools/builtin/todo/todo_test.go @@ -11,7 +11,7 @@ import ( ) func TestTodoTool_DisplayNames(t *testing.T) { - tool := NewTodoTool() + tool := New() all, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -24,7 +24,7 @@ func TestTodoTool_DisplayNames(t *testing.T) { func TestTodoTool_CreateTodo(t *testing.T) { storage := NewMemoryTodoStorage() - tool := NewTodoTool(WithStorage(storage)) + tool := New(WithStorage(storage)) result, err := tool.handler.createTodo(t.Context(), CreateTodoArgs{ Description: "Test todo item", @@ -48,7 +48,7 @@ func TestTodoTool_CreateTodo(t *testing.T) { func TestTodoTool_CreateTodos(t *testing.T) { storage := NewMemoryTodoStorage() - tool := NewTodoTool(WithStorage(storage)) + tool := New(WithStorage(storage)) result, err := tool.handler.createTodos(t.Context(), CreateTodosArgs{ Descriptions: []string{"First", "Second", "Third"}, @@ -88,7 +88,7 @@ func TestTodoTool_CreateTodos(t *testing.T) { } func TestTodoTool_ListTodos(t *testing.T) { - tool := NewTodoTool() + tool := New() descs := []string{"First", "Second", "Third"} for _, d := range descs { @@ -116,7 +116,7 @@ func TestTodoTool_ListTodos(t *testing.T) { } func TestTodoTool_ListTodos_Empty(t *testing.T) { - tool := NewTodoTool() + tool := New() result, err := tool.handler.listTodos(t.Context(), tools.ToolCall{}) require.NoError(t, err) @@ -131,7 +131,7 @@ func TestTodoTool_ListTodos_Empty(t *testing.T) { func TestTodoTool_UpdateTodos(t *testing.T) { storage := NewMemoryTodoStorage() - tool := NewTodoTool(WithStorage(storage)) + tool := New(WithStorage(storage)) _, err := tool.handler.createTodos(t.Context(), CreateTodosArgs{ Descriptions: []string{"First", "Second", "Third"}, @@ -178,7 +178,7 @@ func TestTodoTool_UpdateTodos(t *testing.T) { func TestTodoTool_UpdateTodos_PartialFailure(t *testing.T) { storage := NewMemoryTodoStorage() - tool := NewTodoTool(WithStorage(storage)) + tool := New(WithStorage(storage)) _, err := tool.handler.createTodos(t.Context(), CreateTodosArgs{ Descriptions: []string{"First", "Second"}, @@ -211,7 +211,7 @@ func TestTodoTool_UpdateTodos_PartialFailure(t *testing.T) { } func TestTodoTool_UpdateTodos_AllNotFound(t *testing.T) { - tool := NewTodoTool() + tool := New() result, err := tool.handler.updateTodos(t.Context(), UpdateTodosArgs{ Updates: []Update{ @@ -232,7 +232,7 @@ func TestTodoTool_UpdateTodos_AllNotFound(t *testing.T) { func TestTodoTool_UpdateTodos_AllCompleted_NoAutoRemoval(t *testing.T) { storage := NewMemoryTodoStorage() - tool := NewTodoTool(WithStorage(storage)) + tool := New(WithStorage(storage)) _, err := tool.handler.createTodos(t.Context(), CreateTodosArgs{ Descriptions: []string{"First", "Second"}, @@ -264,7 +264,7 @@ func TestTodoTool_UpdateTodos_AllCompleted_NoAutoRemoval(t *testing.T) { func TestTodoTool_WithStorage(t *testing.T) { storage := NewMemoryTodoStorage() - tool := NewTodoTool(WithStorage(storage)) + tool := New(WithStorage(storage)) _, err := tool.handler.createTodo(t.Context(), CreateTodoArgs{Description: "Test item"}) require.NoError(t, err) @@ -280,7 +280,7 @@ func TestTodoTool_WithStorage_NilPanics(t *testing.T) { } func TestTodoTool_OutputSchema(t *testing.T) { - tool := NewTodoTool() + tool := New() allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -292,7 +292,7 @@ func TestTodoTool_OutputSchema(t *testing.T) { } func TestTodoTool_ParametersAreObjects(t *testing.T) { - tool := NewTodoTool() + tool := New() allTools, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -307,7 +307,7 @@ func TestTodoTool_ParametersAreObjects(t *testing.T) { } func TestTodoTool_CreateTodo_FullStateOutput(t *testing.T) { - tool := NewTodoTool() + tool := New() // Create first todo result1, err := tool.handler.createTodo(t.Context(), CreateTodoArgs{Description: "First"}) @@ -328,7 +328,7 @@ func TestTodoTool_CreateTodo_FullStateOutput(t *testing.T) { } func TestTodoTool_UpdateTodos_FullStateOutput(t *testing.T) { - tool := NewTodoTool() + tool := New() _, err := tool.handler.createTodos(t.Context(), CreateTodosArgs{ Descriptions: []string{"A", "B", "C"}, diff --git a/pkg/tools/builtin/transfertask/transfertask.go b/pkg/tools/builtin/transfertask/transfertask.go index 4e0420da9..84fdd1a74 100644 --- a/pkg/tools/builtin/transfertask/transfertask.go +++ b/pkg/tools/builtin/transfertask/transfertask.go @@ -8,9 +8,9 @@ import ( const ToolNameTransferTask = "transfer_task" -type Tool struct{} +type ToolSet struct{} -var _ tools.ToolSet = (*Tool)(nil) +var _ tools.ToolSet = (*ToolSet)(nil) type Args struct { Agent string `json:"agent" jsonschema:"The name of the agent to transfer the task to."` @@ -18,11 +18,11 @@ type Args struct { ExpectedOutput string `json:"expected_output" jsonschema:"The expected output from the member (optional)."` } -func NewTransferTaskTool() *Tool { - return &Tool{} +func New() *ToolSet { + return &ToolSet{} } -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { return []tools.Tool{ { Name: ToolNameTransferTask, diff --git a/pkg/tools/builtin/transfertask/transfertask_test.go b/pkg/tools/builtin/transfertask/transfertask_test.go index bd4cea94b..2b2da1cba 100644 --- a/pkg/tools/builtin/transfertask/transfertask_test.go +++ b/pkg/tools/builtin/transfertask/transfertask_test.go @@ -11,12 +11,12 @@ import ( ) func TestNewTaskTool(t *testing.T) { - tool := NewTransferTaskTool() + tool := New() assert.NotNil(t, tool) } func TestTaskTool_Instructions(t *testing.T) { - tool := NewTransferTaskTool() + tool := New() // Tool doesn't implement Instructable _, ok := any(tool).(tools.Instructable) @@ -24,7 +24,7 @@ func TestTaskTool_Instructions(t *testing.T) { } func TestTaskTool_Tools(t *testing.T) { - tool := NewTransferTaskTool() + tool := New() allTools, err := tool.Tools(t.Context()) @@ -65,7 +65,7 @@ func TestTaskTool_Tools(t *testing.T) { } func TestTaskTool_DisplayNames(t *testing.T) { - tool := NewTransferTaskTool() + tool := New() all, err := tool.Tools(t.Context()) require.NoError(t, err) @@ -78,7 +78,7 @@ func TestTaskTool_DisplayNames(t *testing.T) { } func TestTaskTool_StartStop(t *testing.T) { - tool := NewTransferTaskTool() + tool := New() // Tool doesn't need to implement Startable - // it has no initialization or cleanup requirements diff --git a/pkg/tools/builtin/userprompt/user_prompt_test.go b/pkg/tools/builtin/userprompt/user_prompt_test.go index a6c8d59db..bf811113e 100644 --- a/pkg/tools/builtin/userprompt/user_prompt_test.go +++ b/pkg/tools/builtin/userprompt/user_prompt_test.go @@ -13,7 +13,7 @@ import ( ) func TestUserPromptTool_AcceptResponse(t *testing.T) { - tool := NewUserPromptTool() + tool := New() tool.SetElicitationHandler(func(_ context.Context, req *mcp.ElicitParams) (tools.ElicitationResult, error) { assert.Equal(t, "What is your name?", req.Message) @@ -35,9 +35,9 @@ func TestUserPromptTool_AcceptResponse(t *testing.T) { } func TestUserPromptTool_DeclineResponse(t *testing.T) { - tool := NewUserPromptTool() + tool := New() - tool.SetElicitationHandler(func(_ context.Context, _ *mcp.ElicitParams) (tools.ElicitationResult, error) { + tool.SetElicitationHandler(func(context.Context, *mcp.ElicitParams) (tools.ElicitationResult, error) { return tools.ElicitationResult{ Action: tools.ElicitationActionDecline, }, nil @@ -55,9 +55,9 @@ func TestUserPromptTool_DeclineResponse(t *testing.T) { } func TestUserPromptTool_CancelResponse(t *testing.T) { - tool := NewUserPromptTool() + tool := New() - tool.SetElicitationHandler(func(_ context.Context, _ *mcp.ElicitParams) (tools.ElicitationResult, error) { + tool.SetElicitationHandler(func(context.Context, *mcp.ElicitParams) (tools.ElicitationResult, error) { return tools.ElicitationResult{ Action: tools.ElicitationActionCancel, }, nil @@ -74,7 +74,7 @@ func TestUserPromptTool_CancelResponse(t *testing.T) { } func TestUserPromptTool_WithSchema(t *testing.T) { - tool := NewUserPromptTool() + tool := New() var receivedSchema any tool.SetElicitationHandler(func(_ context.Context, req *mcp.ElicitParams) (tools.ElicitationResult, error) { diff --git a/pkg/tools/builtin/userprompt/userprompt.go b/pkg/tools/builtin/userprompt/userprompt.go index 421845a5e..54419b61a 100644 --- a/pkg/tools/builtin/userprompt/userprompt.go +++ b/pkg/tools/builtin/userprompt/userprompt.go @@ -12,16 +12,18 @@ import ( const ToolNameUserPrompt = "user_prompt" -type Tool struct { - elicitationHandler tools.ElicitationHandler +// CreateToolSet is used by the tools registry. +func CreateToolSet() (tools.ToolSet, error) { + return New(), nil } -// Verify interface compliance -var ( - _ tools.ToolSet = (*Tool)(nil) - _ tools.Elicitable = (*Tool)(nil) - _ tools.Instructable = (*Tool)(nil) -) +func New() *ToolSet { + return &ToolSet{} +} + +type ToolSet struct { + elicitationHandler tools.ElicitationHandler +} type Args struct { Message string `json:"message" jsonschema:"The message/question to display to the user"` @@ -34,15 +36,40 @@ type Response struct { Content map[string]any `json:"content,omitempty" jsonschema:"The user response data (only present when action is accept)"` } -func NewUserPromptTool() *Tool { - return &Tool{} +func (t *ToolSet) SetElicitationHandler(elicitationHandler tools.ElicitationHandler) { + t.elicitationHandler = elicitationHandler } -func (t *Tool) SetElicitationHandler(handler tools.ElicitationHandler) { - t.elicitationHandler = handler +func (t *ToolSet) Instructions() string { + return `## User Prompt Tool + +Ask the user a question when you need clarification, input, or a decision. + +Optionally provide a JSON schema to structure the response: +- Enum: {"type": "string", "enum": ["option1", "option2"], "title": "Select"} +- Object: {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} + +Response contains "action" (accept/decline/cancel) and "content" (user data when accepted).` +} + +func (t *ToolSet) Tools(context.Context) ([]tools.Tool, error) { + return []tools.Tool{ + { + Name: ToolNameUserPrompt, + Category: "user_prompt", + Description: "Ask the user a question and wait for their response. Use this when you need interactive input, clarification, or confirmation from the user. Optionally provide a JSON schema to define the expected response structure.", + Parameters: tools.MustSchemaFor[Args](), + OutputSchema: tools.MustSchemaFor[Response](), + Handler: tools.NewHandler(t.userPrompt), + Annotations: tools.ToolAnnotations{ + ReadOnlyHint: true, + Title: "User Prompt", + }, + }, + }, nil } -func (t *Tool) userPrompt(ctx context.Context, params Args) (*tools.ToolCallResult, error) { +func (t *ToolSet) userPrompt(ctx context.Context, params Args) (*tools.ToolCallResult, error) { if t.elicitationHandler == nil { return tools.ResultError("user_prompt tool is not available in this context (no elicitation handler configured)"), nil } @@ -79,32 +106,3 @@ func (t *Tool) userPrompt(ctx context.Context, params Args) (*tools.ToolCallResu return tools.ResultSuccess(string(responseJSON)), nil } - -func (t *Tool) Instructions() string { - return `## User Prompt Tool - -Ask the user a question when you need clarification, input, or a decision. - -Optionally provide a JSON schema to structure the response: -- Enum: {"type": "string", "enum": ["option1", "option2"], "title": "Select"} -- Object: {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} - -Response contains "action" (accept/decline/cancel) and "content" (user data when accepted).` -} - -func (t *Tool) Tools(context.Context) ([]tools.Tool, error) { - return []tools.Tool{ - { - Name: ToolNameUserPrompt, - Category: "user_prompt", - Description: "Ask the user a question and wait for their response. Use this when you need interactive input, clarification, or confirmation from the user. Optionally provide a JSON schema to define the expected response structure.", - Parameters: tools.MustSchemaFor[Args](), - OutputSchema: tools.MustSchemaFor[Response](), - Handler: tools.NewHandler(t.userPrompt), - Annotations: tools.ToolAnnotations{ - ReadOnlyHint: true, - Title: "User Prompt", - }, - }, - }, nil -} diff --git a/pkg/tools/lifecycle/policy.go b/pkg/tools/lifecycle/policy.go new file mode 100644 index 000000000..5deed9a50 --- /dev/null +++ b/pkg/tools/lifecycle/policy.go @@ -0,0 +1,67 @@ +package lifecycle + +import ( + "log/slog" + + "github.com/docker/docker-agent/pkg/config/latest" +) + +// PolicyFromConfig converts a latest.LifecycleConfig into a Policy. nil cfg +// returns the resilient default policy. +func PolicyFromConfig(name string, cfg *latest.LifecycleConfig) Policy { + policy := profilePolicy(profileName(cfg)) + policy.Logger = slog.With("component", "supervisor", "toolset", name) + + if cfg == nil { + return policy + } + if cfg.Restart != "" { + policy.Restart = ParseRestart(cfg.Restart) + } + if cfg.MaxRestarts != 0 { + policy.MaxAttempts = cfg.MaxRestarts + } + if b := cfg.Backoff; b != nil { + if b.Initial.Duration > 0 { + policy.Backoff.Initial = b.Initial.Duration + } + if b.Max.Duration > 0 { + policy.Backoff.Max = b.Max.Duration + } + if b.Multiplier > 0 { + policy.Backoff.Multiplier = b.Multiplier + } + if b.Jitter > 0 { + policy.Backoff.Jitter = b.Jitter + } + } + return policy +} + +func profileName(cfg *latest.LifecycleConfig) string { + if cfg == nil || cfg.Profile == "" { + return latest.LifecycleProfileResilient + } + return cfg.Profile +} + +func profilePolicy(profile string) Policy { + switch profile { + case latest.LifecycleProfileStrict, latest.LifecycleProfileBestEffort: + return Policy{Restart: RestartNever, MaxAttempts: -1} + default: + return Policy{Restart: RestartOnFailure, MaxAttempts: 5} + } +} + +// ParseRestart converts a YAML restart string into the supervisor enum. +func ParseRestart(s string) Restart { + switch s { + case "never": + return RestartNever + case "always": + return RestartAlways + default: + return RestartOnFailure + } +} diff --git a/pkg/tools/lifecycle/policy_test.go b/pkg/tools/lifecycle/policy_test.go new file mode 100644 index 000000000..ad7100ca1 --- /dev/null +++ b/pkg/tools/lifecycle/policy_test.go @@ -0,0 +1,81 @@ +package lifecycle + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/docker/docker-agent/pkg/config/latest" +) + +func TestPolicyFromConfig_NilUsesResilientDefaults(t *testing.T) { + t.Parallel() + p := PolicyFromConfig("test", nil) + assert.Equal(t, RestartOnFailure, p.Restart) + assert.Equal(t, 5, p.MaxAttempts) + assert.NotNil(t, p.Logger) +} + +func TestPolicyFromConfig_StrictProfile(t *testing.T) { + t.Parallel() + p := PolicyFromConfig("test", &latest.LifecycleConfig{ + Profile: latest.LifecycleProfileStrict, + }) + assert.Equal(t, RestartNever, p.Restart) + assert.Equal(t, -1, p.MaxAttempts) +} + +func TestPolicyFromConfig_BestEffortProfile(t *testing.T) { + t.Parallel() + p := PolicyFromConfig("test", &latest.LifecycleConfig{ + Profile: latest.LifecycleProfileBestEffort, + }) + assert.Equal(t, RestartNever, p.Restart) +} + +func TestPolicyFromConfig_ExplicitOverrides(t *testing.T) { + t.Parallel() + cfg := &latest.LifecycleConfig{ + Profile: latest.LifecycleProfileResilient, + Restart: "always", + MaxRestarts: 12, + Backoff: &latest.BackoffConfig{ + Initial: latest.Duration{Duration: 500 * time.Millisecond}, + Max: latest.Duration{Duration: 10 * time.Second}, + Multiplier: 1.5, + Jitter: 0.3, + }, + } + p := PolicyFromConfig("test", cfg) + assert.Equal(t, RestartAlways, p.Restart) + assert.Equal(t, 12, p.MaxAttempts) + assert.Equal(t, 500*time.Millisecond, p.Backoff.Initial) + assert.Equal(t, 10*time.Second, p.Backoff.Max) + assert.InDelta(t, 1.5, p.Backoff.Multiplier, 0.001) + assert.InDelta(t, 0.3, p.Backoff.Jitter, 0.001) +} + +func TestPolicyFromConfig_PartialOverridesKeepProfileDefaults(t *testing.T) { + t.Parallel() + cfg := &latest.LifecycleConfig{ + Profile: latest.LifecycleProfileResilient, + MaxRestarts: 7, + } + p := PolicyFromConfig("test", cfg) + assert.Equal(t, RestartOnFailure, p.Restart, "profile default preserved") + assert.Equal(t, 7, p.MaxAttempts, "explicit override applied") +} + +func TestParseRestart(t *testing.T) { + t.Parallel() + cases := map[string]Restart{ + "": RestartOnFailure, + "on_failure": RestartOnFailure, + "never": RestartNever, + "always": RestartAlways, + } + for in, want := range cases { + assert.Equal(t, want, ParseRestart(in), "input=%q", in) + } +} diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index eb9b55240..445880f44 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -11,17 +11,99 @@ import ( "iter" "log/slog" "net/url" + "os" "strings" "sync" "time" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/gateway" + "github.com/docker/docker-agent/pkg/js" + "github.com/docker/docker-agent/pkg/toolinstall" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/lifecycle" + "github.com/docker/docker-agent/pkg/tools/workingdir" ) +// CreateToolSet is used by the tools registry. +func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + envProvider := runConfig.EnvProvider() + cwd := workingdir.Resolve(toolset.WorkingDir, runConfig.WorkingDir) + + if toolset.WorkingDir != "" && toolset.Ref == "" { + if err := workingdir.CheckDirExists(cwd, "mcp"); err != nil { + return nil, err + } + } + + switch { + case toolset.Ref != "": + mcpServerName := gateway.ParseServerRef(toolset.Ref) + serverSpec, err := gateway.ServerSpec(ctx, mcpServerName) + if err != nil { + return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err) + } + + if serverSpec.Type == "remote" { + if toolset.WorkingDir != "" { + return nil, fmt.Errorf("working_dir is not supported for MCP toolset %q: ref %q resolves to a remote server (no local subprocess)", + toolset.Name, toolset.Ref) + } + return NewRemoteToolset(toolset.Name, serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, nil, lifecycle.PolicyFromConfig(toolset.Name, toolset.Lifecycle)), nil + } + + if toolset.WorkingDir != "" { + if err := workingdir.CheckDirExists(cwd, "mcp"); err != nil { + return nil, err + } + } + + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + + envProvider := environment.NewMultiProvider( + environment.NewEnvListProvider(env), + envProvider, + ) + + return NewGatewayToolset(ctx, toolset.Name, mcpServerName, serverSpec.Secrets, toolset.Config, envProvider, cwd) + + case toolset.Command != "": + resolvedCommand, err := toolinstall.EnsureCommand(ctx, toolset.Command, toolset.Version) + if err != nil { + slog.WarnContext(ctx, "MCP command not yet available, will retry on next turn", + "command", toolset.Command, "error", err) + resolvedCommand = toolset.Command + } + + env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider) + if err != nil { + return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) + } + env = append(env, os.Environ()...) + env = toolinstall.PrependBinDirToEnv(env) + + return NewToolsetCommand(toolset.Name, resolvedCommand, toolset.Args, env, cwd, lifecycle.PolicyFromConfig(toolset.Name, toolset.Lifecycle)), nil + + case toolset.Remote.URL != "": + expander := js.NewJsExpander(envProvider) + + headers := expander.ExpandMap(ctx, toolset.Remote.Headers) + remoteURL := expander.Expand(ctx, toolset.Remote.URL, nil) + + return NewRemoteToolset(toolset.Name, remoteURL, toolset.Remote.TransportType, headers, toolset.Remote.OAuth, lifecycle.PolicyFromConfig(toolset.Name, toolset.Lifecycle)), nil + + default: + return nil, errors.New("mcp toolset requires either ref, command, or remote configuration") + } +} + type mcpClient interface { Initialize(ctx context.Context, request *mcp.InitializeRequest) (*mcp.InitializeResult, error) ListTools(ctx context.Context, request *mcp.ListToolsParams) iter.Seq2[*mcp.Tool, error] diff --git a/pkg/tools/toolsetpath/toolsetpath.go b/pkg/tools/toolsetpath/toolsetpath.go new file mode 100644 index 000000000..b259d8b26 --- /dev/null +++ b/pkg/tools/toolsetpath/toolsetpath.go @@ -0,0 +1,38 @@ +package toolsetpath + +import ( + "path/filepath" + + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/path" +) + +// Resolve returns the validated absolute path for a toolset-specific file or +// directory (e.g. memory database, tasks file). It expands ~ and ${VAR}, +// resolves relative paths against runConfig.WorkingDir or parentDir, and +// validates that the result is contained within the base directory. +// +// Resolution rules: +// - Shell patterns (~ and ${VAR}/$VAR) are expanded first. +// - If the expanded path is absolute, basePath is empty (no containment check). +// - If the expanded path is relative and runConfig.WorkingDir is non-empty, +// basePath is runConfig.WorkingDir. +// - If the expanded path is relative and runConfig.WorkingDir is empty, +// basePath is parentDir. +// +// The final path is validated via path.ValidatePathInDirectory to prevent +// directory traversal attacks. +func Resolve(toolsetPath, parentDir string, runConfig *config.RuntimeConfig) (string, error) { + toolsetPath = path.ExpandPath(toolsetPath) + + var basePath string + if filepath.IsAbs(toolsetPath) { + basePath = "" + } else if wd := runConfig.WorkingDir; wd != "" { + basePath = wd + } else { + basePath = parentDir + } + + return path.ValidatePathInDirectory(toolsetPath, basePath) +} diff --git a/pkg/tools/workingdir/workingdir.go b/pkg/tools/workingdir/workingdir.go new file mode 100644 index 000000000..ed19446b2 --- /dev/null +++ b/pkg/tools/workingdir/workingdir.go @@ -0,0 +1,56 @@ +package workingdir + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/docker/docker-agent/pkg/path" +) + +// Resolve returns the effective working directory for a toolset process. +func Resolve(toolsetWorkingDir, agentWorkingDir string) string { + if toolsetWorkingDir == "" { + return agentWorkingDir + } + toolsetWorkingDir = path.ExpandPath(toolsetWorkingDir) + if filepath.IsAbs(toolsetWorkingDir) { + return toolsetWorkingDir + } + if agentWorkingDir != "" { + abs, err := filepath.Abs(filepath.Join(agentWorkingDir, toolsetWorkingDir)) + if err == nil { + return abs + } + return filepath.Join(agentWorkingDir, toolsetWorkingDir) + } + return toolsetWorkingDir +} + +// Default returns the configured agent working directory or the process cwd. +func Default(agentWorkingDir string) string { + if agentWorkingDir != "" { + return agentWorkingDir + } + wd, err := os.Getwd() + if err != nil { + return "." + } + return wd +} + +// CheckDirExists returns an error if the given directory does not exist or is +// not a directory. kind is used only in the error message. +func CheckDirExists(dir, kind string) error { + info, err := os.Stat(dir) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("working_dir %q for %s toolset does not exist", dir, kind) + } + return fmt.Errorf("working_dir %q for %s toolset: %w", dir, kind, err) + } + if !info.IsDir() { + return fmt.Errorf("working_dir %q for %s toolset is not a directory", dir, kind) + } + return nil +} diff --git a/pkg/tools/workingdir/workingdir_test.go b/pkg/tools/workingdir/workingdir_test.go new file mode 100644 index 000000000..52f7514ac --- /dev/null +++ b/pkg/tools/workingdir/workingdir_test.go @@ -0,0 +1,40 @@ +package workingdir + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResolve(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + tests := []struct { + name string + toolsetWorkingDir string + agentWorkingDir string + want string + }{ + {name: "empty uses agent working dir", agentWorkingDir: "/workspace", want: "/workspace"}, + {name: "absolute wins", toolsetWorkingDir: "/tmp/app", agentWorkingDir: "/workspace", want: "/tmp/app"}, + {name: "relative joins agent dir", toolsetWorkingDir: "tools/mcp", agentWorkingDir: "/workspace", want: "/workspace/tools/mcp"}, + {name: "relative without agent dir remains relative", toolsetWorkingDir: "tools/mcp", want: "tools/mcp"}, + {name: "tilde expands", toolsetWorkingDir: "~/projects/app", agentWorkingDir: "/workspace", want: filepath.Join(home, "projects", "app")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Resolve(tt.toolsetWorkingDir, tt.agentWorkingDir) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestResolveEnvVarExpansion(t *testing.T) { + t.Setenv("TEST_WORKING_DIR_VAR", "/custom/path") + + got := Resolve("${TEST_WORKING_DIR_VAR}/app", "/workspace") + assert.Equal(t, "/custom/path/app", got) +}