diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go
index a0a441970..df5e2e4a7 100644
--- a/pkg/runtime/loop.go
+++ b/pkg/runtime/loop.go
@@ -831,6 +831,7 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events EventSink
 	for _, toolset := range a.ToolSets() {
 		tools.ConfigureHandlers(toolset,
 			r.elicitationHandler,
+			r.samplingHandler,
 			func() { events.Emit(Authorization(tools.ElicitationActionAccept, a.Name())) },
 			r.managedOAuth,
 		)
diff --git a/pkg/runtime/sampling.go b/pkg/runtime/sampling.go
new file mode 100644
index 000000000..c1ff91559
--- /dev/null
+++ b/pkg/runtime/sampling.go
@@ -0,0 +1,267 @@
+package runtime
+
+import (
+	"context"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"io"
+	"log/slog"
+	"strings"
+
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+
+	"github.com/docker/docker-agent/pkg/chat"
+	"github.com/docker/docker-agent/pkg/model/provider"
+	"github.com/docker/docker-agent/pkg/model/provider/options"
+)
+
+// Limits applied to inbound sampling requests to keep a misbehaving or
+// malicious MCP server from inflating host memory / token spend without
+// any natural backpressure. Each limit is checked per request or per
+// block; nothing here bounds the combined size of all blocks in a request.
+const (
+	// maxSamplingMessages caps the number of conversation turns we accept
+	// from a single sampling/createMessage request.
+	maxSamplingMessages = 256
+	// maxSamplingTextBytes caps the size of an individual text block
+	// (including the system prompt) before we refuse the request.
+	maxSamplingTextBytes = 1 << 20 // 1 MiB
+	// maxSamplingBinaryBytes caps the size of an individual image/audio
+	// block before we refuse to inline it as a data URL. The cap is applied
+	// to the raw bytes, i.e. before base64 expansion.
+	maxSamplingBinaryBytes = 8 << 20 // 8 MiB
+)
+
+// samplingHandler is the MCP-toolset-side hook that satisfies an inbound
+// sampling/createMessage request from a server by driving the host agent's
+// own model and returning the resulting message.
+//
+// The host always remains in control: the request is mapped to the agent's
+// configured model (server-supplied ModelPreferences are advisory only),
+// only one round-trip is performed (the model's response is returned
+// verbatim, not fed back into the loop), and tool use is intentionally
+// disabled — sampling is for plain text/image/audio completions, not
+// nested agent runs. Per-block size and per-request message-count limits
+// keep an unbounded server request from pinning host memory.
+//
+// It returns an error when req is nil, when there is no current agent or
+// the agent has no model, when the request fails validation/conversion,
+// or when creating or reading the completion stream fails.
+func (r *LocalRuntime) samplingHandler(ctx context.Context, req *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error) {
+	if req == nil {
+		return nil, errors.New("sampling request is nil")
+	}
+
+	// Log only the shape of the request (message count, token budget and
+	// whether a system prompt is present), never the prompt text itself.
+	slog.InfoContext(ctx, "Sampling request received from MCP server",
+		"messages", len(req.Messages),
+		"max_tokens", req.MaxTokens,
+		"system_prompt", req.SystemPrompt != "",
+	)
+
+	a := r.CurrentAgent()
+	if a == nil {
+		return nil, errors.New("no current agent available to handle sampling request")
+	}
+
+	// Validate and convert before the model is involved, so a malformed or
+	// oversized request is rejected before any completion is started.
+	messages, err := samplingMessagesToChat(req)
+	if err != nil {
+		return nil, fmt.Errorf("converting sampling messages: %w", err)
+	}
+
+	baseModel := a.Model(ctx)
+	if baseModel == nil {
+		return nil, errors.New("current agent has no model configured")
+	}
+
+	// NOTE(review): CloneWithOptions presumably returns a copy so that the
+	// sampling-specific options do not affect the agent's own model — confirm
+	// in pkg/model/provider.
+	model := provider.CloneWithOptions(ctx, baseModel, samplingModelOptions(req)...)
+
+	// The last argument is nil; presumably this is the tool list, which
+	// matches the "tool use is disabled" contract documented above.
+	stream, err := model.CreateChatCompletionStream(ctx, messages, nil)
+	if err != nil {
+		return nil, fmt.Errorf("creating sampling completion stream: %w", err)
+	}
+
+	content, finishReason, err := drainSamplingStream(stream)
+	if err != nil {
+		return nil, fmt.Errorf("reading sampling completion stream: %w", err)
+	}
+
+	slog.DebugContext(ctx, "Sampling request completed",
+		"agent", a.Name(),
+		"model", model.ID().String(),
+		"finish_reason", finishReason,
+		"content_bytes", len(content),
+	)
+
+	// The reply is always a single text block with the assistant role,
+	// regardless of which content types the request contained.
+	return &mcp.CreateMessageResult{
+		Role:       mcp.Role("assistant"),
+		Model:      model.ID().String(),
+		Content:    &mcp.TextContent{Text: content},
+		StopReason: stopReason(finishReason),
+	}, nil
+}
+
+// samplingMessagesToChat converts an MCP CreateMessageParams into the
+// host's chat.Message slice. The optional system prompt is prepended;
+// per-message Content is mapped from the supported MCP block types.
+// Oversized payloads and nil/unsupported entries surface as errors so
+// the request is rejected rather than silently truncated.
+// Every per-message failure (nil entry, unsupported role, rejected content)
+// is reported together with the index of the offending message. A nil req
+// is rejected with an error instead of being dereferenced.
+func samplingMessagesToChat(req *mcp.CreateMessageParams) ([]chat.Message, error) {
+	if req == nil {
+		return nil, errors.New("sampling request is nil")
+	}
+	if len(req.Messages) == 0 {
+		return nil, errors.New("sampling request contains no messages")
+	}
+	if len(req.Messages) > maxSamplingMessages {
+		return nil, fmt.Errorf("sampling request contains %d messages (limit %d)",
+			len(req.Messages), maxSamplingMessages)
+	}
+
+	// One extra slot for the optional system prompt.
+	messages := make([]chat.Message, 0, len(req.Messages)+1)
+	if req.SystemPrompt != "" {
+		if len(req.SystemPrompt) > maxSamplingTextBytes {
+			return nil, fmt.Errorf("sampling system prompt is too large (%d bytes, limit %d)",
+				len(req.SystemPrompt), maxSamplingTextBytes)
+		}
+		messages = append(messages, chat.Message{
+			Role:    chat.MessageRoleSystem,
+			Content: req.SystemPrompt,
+		})
+	}
+	for i, m := range req.Messages {
+		if m == nil {
+			return nil, fmt.Errorf("sampling message at index %d is nil", i)
+		}
+		role, err := samplingRoleToChat(m.Role)
+		if err != nil {
+			// Carry the index here too, so a bad role is as easy to locate
+			// as a bad content block.
+			return nil, fmt.Errorf("sampling message at index %d: %w", i, err)
+		}
+		text, parts, err := samplingContentToChat(m.Content)
+		if err != nil {
+			return nil, fmt.Errorf("sampling message at index %d: %w", i, err)
+		}
+		messages = append(messages, chat.Message{
+			Role:         role,
+			Content:      text,
+			MultiContent: parts,
+		})
+	}
+	return messages, nil
+}
+
+// samplingRoleToChat maps an MCP role onto the host's chat role. "user" and
+// "assistant" map directly, an empty role is treated as "user", and any
+// other value is rejected with an error.
+func samplingRoleToChat(r mcp.Role) (chat.MessageRole, error) {
+	switch string(r) {
+	case "user":
+		return chat.MessageRoleUser, nil
+	case "assistant":
+		return chat.MessageRoleAssistant, nil
+	case "":
+		// Some servers omit the role for the lone user turn; default to user
+		// rather than refuse the request, matching most other MCP hosts.
+		return chat.MessageRoleUser, nil
+	default:
+		return "", fmt.Errorf("unsupported sampling role %q", r)
+	}
+}
+
+// samplingContentToChat maps a single MCP content block to the host's
+// chat representation. Text blocks return a Content string; image blocks
+// return a MultiContent entry with a data URL the model can consume.
+// Audio blocks fall back to a textual placeholder because chat.Message +// does not currently model raw audio; this lets models acknowledge the +// attachment instead of failing the request outright. Oversized blocks +// are rejected so a malicious or buggy server can't pin large blobs in +// host memory. +func samplingContentToChat(c mcp.Content) (string, []chat.MessagePart, error) { + switch v := c.(type) { + case *mcp.TextContent: + if len(v.Text) > maxSamplingTextBytes { + return "", nil, fmt.Errorf("text block too large (%d bytes, limit %d)", + len(v.Text), maxSamplingTextBytes) + } + return v.Text, nil, nil + case *mcp.ImageContent: + if len(v.Data) > maxSamplingBinaryBytes { + return "", nil, fmt.Errorf("image block too large (%d bytes, limit %d)", + len(v.Data), maxSamplingBinaryBytes) + } + return "", []chat.MessagePart{{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + URL: dataURL(v.MIMEType, v.Data), + }, + }}, nil + case *mcp.AudioContent: + if len(v.Data) > maxSamplingBinaryBytes { + return "", nil, fmt.Errorf("audio block too large (%d bytes, limit %d)", + len(v.Data), maxSamplingBinaryBytes) + } + return fmt.Sprintf("[audio attachment (%s, %d bytes) — not inlined]", + v.MIMEType, len(v.Data)), nil, nil + case nil: + return "", nil, nil + default: + return fmt.Sprintf("[unsupported content type %T]", v), nil, nil + } +} + +func dataURL(mimeType string, data []byte) string { + mt := mimeType + if mt == "" { + mt = "application/octet-stream" + } + return "data:" + mt + ";base64," + base64.StdEncoding.EncodeToString(data) +} + +// samplingModelOptions translates the server's advisory preferences into +// the host's model options. Only MaxTokens is honoured today (with an +// upper bound enforced by the underlying provider); temperature, stop +// sequences, and ModelPreferences are intentionally left to the host's +// configuration. 
Structured output is explicitly cleared so a request +// cannot inherit the agent's JSON-schema response format and silently +// reshape the model's reply into something the MCP server didn't ask +// for. +func samplingModelOptions(req *mcp.CreateMessageParams) []options.Opt { + opts := []options.Opt{ + options.WithStructuredOutput(nil), + options.WithNoThinking(), + } + if req.MaxTokens > 0 { + opts = append(opts, options.WithMaxTokens(req.MaxTokens)) + } + return opts +} + +// drainSamplingStream reads a chat completion stream to completion and +// returns the concatenated assistant content alongside the final finish +// reason. The stream is always closed before returning. +func drainSamplingStream(stream chat.MessageStream) (string, chat.FinishReason, error) { + defer stream.Close() + + var content strings.Builder + var finishReason chat.FinishReason + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + return content.String(), finishReason, nil + } + if err != nil { + return "", "", err + } + if len(response.Choices) > 0 { + choice := response.Choices[0] + content.WriteString(choice.Delta.Content) + if choice.FinishReason != "" { + finishReason = choice.FinishReason + } + } + } +} + +// stopReason maps a chat finish reason into the MCP stopReason vocabulary +// used in CreateMessageResult. Unknown values fall back to "endTurn", +// which is the protocol's default for a normal assistant turn. 
+func stopReason(fr chat.FinishReason) string {
+	switch fr {
+	case chat.FinishReasonStop:
+		return "endTurn"
+	case chat.FinishReasonLength:
+		return "maxTokens"
+	case chat.FinishReasonToolCalls:
+		return "toolUse"
+	default:
+		// Anything else, including the empty finish reason, is reported as a
+		// normal end of turn.
+		return "endTurn"
+	}
+}
diff --git a/pkg/runtime/sampling_test.go b/pkg/runtime/sampling_test.go
new file mode 100644
index 000000000..70f55ca4b
--- /dev/null
+++ b/pkg/runtime/sampling_test.go
@@ -0,0 +1,221 @@
+package runtime
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/docker/docker-agent/pkg/chat"
+)
+
+// TestSamplingMessagesToChat is a table-driven test of samplingMessagesToChat.
+// Success cases compare the full converted slice; failure cases only assert
+// that an error is returned, not its message.
+func TestSamplingMessagesToChat(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		req     *mcp.CreateMessageParams
+		want    []chat.Message
+		wantErr bool
+	}{
+		{
+			name: "single user text message",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{Role: "user", Content: &mcp.TextContent{Text: "hello"}},
+				},
+			},
+			want: []chat.Message{
+				{Role: chat.MessageRoleUser, Content: "hello"},
+			},
+		},
+		{
+			name: "system prompt is prepended",
+			req: &mcp.CreateMessageParams{
+				SystemPrompt: "be terse",
+				Messages: []*mcp.SamplingMessage{
+					{Role: "user", Content: &mcp.TextContent{Text: "hi"}},
+				},
+			},
+			want: []chat.Message{
+				{Role: chat.MessageRoleSystem, Content: "be terse"},
+				{Role: chat.MessageRoleUser, Content: "hi"},
+			},
+		},
+		{
+			name: "user and assistant turns are preserved",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{Role: "user", Content: &mcp.TextContent{Text: "ping"}},
+					{Role: "assistant", Content: &mcp.TextContent{Text: "pong"}},
+					{Role: "user", Content: &mcp.TextContent{Text: "again"}},
+				},
+			},
+			want: []chat.Message{
+				{Role: chat.MessageRoleUser, Content: "ping"},
+				{Role: chat.MessageRoleAssistant, Content: "pong"},
+				{Role: chat.MessageRoleUser, Content: "again"},
+			},
+		},
+		{
+			// The expected URL is the base64 encoding of "PNGBYTES".
+			name: "image content becomes a data URL multi-part",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{
+						Role:    "user",
+						Content: &mcp.ImageContent{Data: []byte("PNGBYTES"), MIMEType: "image/png"},
+					},
+				},
+			},
+			want: []chat.Message{
+				{
+					Role: chat.MessageRoleUser,
+					MultiContent: []chat.MessagePart{{
+						Type: chat.MessagePartTypeImageURL,
+						ImageURL: &chat.MessageImageURL{
+							URL: "data:image/png;base64,UE5HQllURVM=",
+						},
+					}},
+				},
+			},
+		},
+		{
+			name: "audio content falls back to a text placeholder",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{Role: "user", Content: &mcp.AudioContent{Data: []byte("WAV"), MIMEType: "audio/wav"}},
+				},
+			},
+			want: []chat.Message{
+				{Role: chat.MessageRoleUser, Content: "[audio attachment (audio/wav, 3 bytes) — not inlined]"},
+			},
+		},
+		{
+			name: "missing role defaults to user",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{Content: &mcp.TextContent{Text: "anonymous"}},
+				},
+			},
+			want: []chat.Message{
+				{Role: chat.MessageRoleUser, Content: "anonymous"},
+			},
+		},
+		{
+			name: "unsupported role surfaces as an error",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{Role: "tool", Content: &mcp.TextContent{Text: "nope"}},
+				},
+			},
+			wantErr: true,
+		},
+		{
+			name:    "empty request is rejected",
+			req:     &mcp.CreateMessageParams{},
+			wantErr: true,
+		},
+		{
+			name: "system-prompt-only request is rejected",
+			req: &mcp.CreateMessageParams{
+				SystemPrompt: "no messages, only a system prompt",
+			},
+			wantErr: true,
+		},
+		{
+			name: "nil message entry is rejected",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{nil},
+			},
+			wantErr: true,
+		},
+		{
+			// One byte over the per-block text limit.
+			name: "oversize text block is rejected",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{Role: "user", Content: &mcp.TextContent{Text: strings.Repeat("a", maxSamplingTextBytes+1)}},
+				},
+			},
+			wantErr: true,
+		},
+		{
+			// One byte over the per-block binary limit.
+			name: "oversize image block is rejected",
+			req: &mcp.CreateMessageParams{
+				Messages: []*mcp.SamplingMessage{
+					{Role: "user", Content: &mcp.ImageContent{Data: make([]byte, maxSamplingBinaryBytes+1), MIMEType: "image/png"}},
+				},
+			},
+			wantErr: true,
+		},
+		{
+			name: "oversize system prompt is rejected",
+			req: &mcp.CreateMessageParams{
+				SystemPrompt: strings.Repeat("a", maxSamplingTextBytes+1),
+				Messages: []*mcp.SamplingMessage{
+					{Role: "user", Content: &mcp.TextContent{Text: "hi"}},
+				},
+			},
+			wantErr: true,
+		},
+		{
+			name: "too many messages is rejected",
+			req: &mcp.CreateMessageParams{
+				Messages: tooManyMessages(maxSamplingMessages + 1),
+			},
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			got, err := samplingMessagesToChat(tt.req)
+			if tt.wantErr {
+				require.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
+// tooManyMessages builds n identical one-character user text messages, used
+// to exceed the per-request message-count limit.
+func tooManyMessages(n int) []*mcp.SamplingMessage {
+	out := make([]*mcp.SamplingMessage, n)
+	for i := range out {
+		out[i] = &mcp.SamplingMessage{Role: "user", Content: &mcp.TextContent{Text: "x"}}
+	}
+	return out
+}
+
+// TestStopReasonMapping checks stopReason for each known finish reason and
+// for the fallback to "endTurn".
+func TestStopReasonMapping(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		in   chat.FinishReason
+		want string
+	}{
+		{chat.FinishReasonStop, "endTurn"},
+		{chat.FinishReasonLength, "maxTokens"},
+		{chat.FinishReasonToolCalls, "toolUse"},
+		{chat.FinishReasonNull, "endTurn"},
+		{"", "endTurn"},
+	}
+
+	for _, tt := range tests {
+		// The subtest name is the finish reason itself.
+		t.Run(string(tt.in), func(t *testing.T) {
+			t.Parallel()
+			assert.Equal(t, tt.want, stopReason(tt.in))
+		})
+	}
+}
+
+// TestDataURL covers an explicit MIME type and the default used when the
+// MIME type is empty.
+func TestDataURL(t *testing.T) {
+	t.Parallel()
+
+	assert.Equal(t, "data:image/png;base64,UE5HQllURVM=", dataURL("image/png", []byte("PNGBYTES")))
+	assert.Equal(t, "data:application/octet-stream;base64,YQ==", dataURL("", []byte("a")))
+}
diff --git a/pkg/tools/capabilities.go b/pkg/tools/capabilities.go
index 669655d3e..d457bd8c4 100644
--- a/pkg/tools/capabilities.go
+++ b/pkg/tools/capabilities.go
@@ -46,6 
+46,13 @@ type Elicitable interface {
 	SetElicitationHandler(handler ElicitationHandler)
 }
 
+// Sampleable is implemented by toolsets that support MCP sampling
+// (sampling/createMessage). MCP servers use sampling to delegate LLM calls
+// back to the host; the handler is expected to drive the host's model.
+type Sampleable interface {
+	// SetSamplingHandler installs the handler for sampling requests.
+	SetSamplingHandler(handler SamplingHandler)
+}
+
 // OAuthCapable is implemented by toolsets that support OAuth flows.
 type OAuthCapable interface {
 	SetOAuthSuccessHandler(handler func())
@@ -68,12 +75,16 @@ type ChangeNotifier interface {
 }
 
 // ConfigureHandlers sets all applicable handlers on a toolset.
-// It checks for Elicitable and OAuthCapable interfaces and configures them.
-// This is a convenience function that handles the capability checking internally.
-func ConfigureHandlers(ts ToolSet, elicitHandler ElicitationHandler, oauthHandler func(), managedOAuth bool) {
+// It checks for Elicitable, Sampleable and OAuthCapable interfaces and
+// configures them. This is a convenience function that handles the capability
+// checking internally.
+//
+// Each handler is passed through as given; in particular a nil
+// samplingHandler is still installed on a Sampleable toolset.
+func ConfigureHandlers(ts ToolSet, elicitHandler ElicitationHandler, samplingHandler SamplingHandler, oauthHandler func(), managedOAuth bool) {
 	if e, ok := As[Elicitable](ts); ok {
 		e.SetElicitationHandler(elicitHandler)
 	}
+	if s, ok := As[Sampleable](ts); ok {
+		s.SetSamplingHandler(samplingHandler)
+	}
 	if o, ok := As[OAuthCapable](ts); ok {
 		o.SetOAuthSuccessHandler(oauthHandler)
 		o.SetManagedOAuth(managedOAuth)
diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go
index eb9b55240..f0cfd1a5a 100644
--- a/pkg/tools/mcp/mcp.go
+++ b/pkg/tools/mcp/mcp.go
@@ -29,6 +29,7 @@ type mcpClient interface {
 	ListPrompts(ctx context.Context, request *mcp.ListPromptsParams) iter.Seq2[*mcp.Prompt, error]
 	GetPrompt(ctx context.Context, request *mcp.GetPromptParams) (*mcp.GetPromptResult, error)
 	SetElicitationHandler(handler tools.ElicitationHandler)
+	SetSamplingHandler(handler tools.SamplingHandler)
 	SetOAuthSuccessHandler(handler func())
 	SetManagedOAuth(managed bool)
 	SetToolListChangedHandler(handler func())
@@ -91,6 +92,7 @@ var (
 var (
 	_ tools.Instructable   = (*Toolset)(nil)
 	_ tools.Elicitable     = (*Toolset)(nil)
+	_ tools.Sampleable     = (*Toolset)(nil)
 	_ tools.OAuthCapable   = (*Toolset)(nil)
 	_ tools.ChangeNotifier = (*Toolset)(nil)
 )
@@ -354,6 +356,7 @@ func (c *clientConnector) Connect(ctx context.Context) (lifecycle.Session, error
 					Form: &mcp.FormElicitationCapabilities{},
 					URL:  &mcp.URLElicitationCapabilities{},
 				},
+				// Advertise sampling support to the server.
+				Sampling: &mcp.SamplingCapabilities{},
 			},
 		},
 	}
@@ -628,6 +631,10 @@ func (ts *Toolset) SetElicitationHandler(handler tools.ElicitationHandler) {
 	ts.mcpClient.SetElicitationHandler(handler)
 }
 
+// SetSamplingHandler forwards handler to the toolset's underlying mcpClient.
+func (ts *Toolset) SetSamplingHandler(handler tools.SamplingHandler) {
+	ts.mcpClient.SetSamplingHandler(handler)
+}
+
 func (ts *Toolset) SetOAuthSuccessHandler(handler func()) {
 	ts.mcpClient.SetOAuthSuccessHandler(handler)
 }
diff --git a/pkg/tools/mcp/mcp_test.go b/pkg/tools/mcp/mcp_test.go
index 8a80e6264..917319844 100644
--- a/pkg/tools/mcp/mcp_test.go
+++ 
b/pkg/tools/mcp/mcp_test.go
@@ -42,6 +42,8 @@ func (m *mockMCPClient) GetPrompt(context.Context, *mcp.GetPromptParams) (*mcp.G
 
 func (m *mockMCPClient) SetElicitationHandler(tools.ElicitationHandler) {}
+
+func (m *mockMCPClient) SetSamplingHandler(tools.SamplingHandler) {}
+
 func (m *mockMCPClient) SetOAuthSuccessHandler(func()) {}
 
 func (m *mockMCPClient) SetManagedOAuth(bool) {}
diff --git a/pkg/tools/mcp/reconnect_test.go b/pkg/tools/mcp/reconnect_test.go
index 71ece482b..5bc7c094d 100644
--- a/pkg/tools/mcp/reconnect_test.go
+++ b/pkg/tools/mcp/reconnect_test.go
@@ -68,6 +68,7 @@ func (m *failingInitClient) GetPrompt(context.Context, *gomcp.GetPromptParams) (
 }
 
 func (m *failingInitClient) SetElicitationHandler(tools.ElicitationHandler) {}
+func (m *failingInitClient) SetSamplingHandler(tools.SamplingHandler) {}
 func (m *failingInitClient) SetOAuthSuccessHandler(func()) {}
 func (m *failingInitClient) SetManagedOAuth(bool) {}
 func (m *failingInitClient) SetToolListChangedHandler(func()) {}
diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go
index cee1df5d3..8aec7bc43 100644
--- a/pkg/tools/mcp/remote.go
+++ b/pkg/tools/mcp/remote.go
@@ -74,6 +74,7 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeReq
 
 	opts := &gomcp.ClientOptions{
 		ElicitationHandler:       c.handleElicitationRequest,
+		CreateMessageHandler:     c.handleSamplingRequest,
 		ToolListChangedHandler:   toolChanged,
 		PromptListChangedHandler: promptChanged,
 	}
diff --git a/pkg/tools/mcp/session_client.go b/pkg/tools/mcp/session_client.go
index c3db9f6e8..cdef1c9a8 100644
--- a/pkg/tools/mcp/session_client.go
+++ b/pkg/tools/mcp/session_client.go
@@ -22,6 +22,7 @@ type sessionClient struct {
 	toolListChangedHandler   func()
 	promptListChangedHandler func()
 	elicitationHandler       tools.ElicitationHandler
+	samplingHandler          tools.SamplingHandler
 	oauthSuccessHandler      func()
 	mu                       sync.RWMutex
 }
@@ -157,6 +158,36 @@ func (c *sessionClient) SetElicitationHandler(handler tools.ElicitationHandler)
 	c.mu.Unlock()
 }
 
+// handleSamplingRequest forwards incoming sampling/createMessage requests
+// from the MCP server to the registered handler. It is used as the gomcp
+// CreateMessageHandler callback for both stdio and remote clients.
+//
+// It returns an error when the request carries no params, when no handler
+// has been registered through SetSamplingHandler, or when the handler fails
+// (that error is wrapped with a "sampling failed" prefix).
+func (c *sessionClient) handleSamplingRequest(ctx context.Context, req *gomcp.CreateMessageRequest) (*gomcp.CreateMessageResult, error) {
+	// Guard first: the debug log below dereferences req.Params, so without
+	// this check a request lacking params would panic before any handler
+	// had the chance to reject it.
+	if req == nil || req.Params == nil {
+		return nil, errors.New("sampling request has no params")
+	}
+
+	slog.DebugContext(ctx, "Received sampling request from MCP server", "messages", len(req.Params.Messages))
+
+	// Read the handler under the read lock, then call it without holding
+	// c.mu.
+	c.mu.RLock()
+	handler := c.samplingHandler
+	c.mu.RUnlock()
+
+	if handler == nil {
+		return nil, errors.New("no sampling handler configured")
+	}
+
+	result, err := handler(ctx, req.Params)
+	if err != nil {
+		return nil, fmt.Errorf("sampling failed: %w", err)
+	}
+
+	return result, nil
+}
+
+// SetSamplingHandler sets the handler that processes sampling requests
+// from the MCP server. The field is written under c.mu; passing nil makes
+// handleSamplingRequest report that no handler is configured.
+func (c *sessionClient) SetSamplingHandler(handler tools.SamplingHandler) {
+	c.mu.Lock()
+	c.samplingHandler = handler
+	c.mu.Unlock()
+}
+
 // requestElicitation invokes the registered elicitation handler directly.
 // This is used by the OAuth transport to trigger elicitation outside of
 // the normal MCP request flow.
diff --git a/pkg/tools/mcp/stdio.go b/pkg/tools/mcp/stdio.go
index 01e3fab25..feb4e3ac5 100644
--- a/pkg/tools/mcp/stdio.go
+++ b/pkg/tools/mcp/stdio.go
@@ -38,9 +38,10 @@ func (c *stdioMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeRequ
 
 	toolChanged, promptChanged := c.notificationHandlers()
 
-	// Create client options with elicitation and notification support
+	// Create client options with elicitation, sampling, and notification support
 	opts := &gomcp.ClientOptions{
 		ElicitationHandler:       c.handleElicitationRequest,
+		CreateMessageHandler:     c.handleSamplingRequest,
 		ToolListChangedHandler:   toolChanged,
 		PromptListChangedHandler: promptChanged,
 	}
diff --git a/pkg/tools/sampling.go b/pkg/tools/sampling.go
new file mode 100644
index 000000000..0bdb24e35
--- /dev/null
+++ b/pkg/tools/sampling.go
@@ -0,0 +1,17 @@
+package tools
+
+import (
+	"context"
+
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+// SamplingHandler is a function type that handles sampling/createMessage
+// requests from an MCP server.
+//
+// MCP servers can use sampling to ask the host application's LLM to generate
+// text on their behalf. The host is in control: it may inspect, modify, or
+// decline the request, and it decides which model is used. The handler is
+// expected to call the host's model with the supplied messages and return
+// the model's response (or an error if the request was declined or failed).
+//
+// req is supplied by the MCP server and should be treated as untrusted
+// input. NOTE(review): whether req can ever be nil depends on the caller —
+// confirm before relying on it being non-nil.
+type SamplingHandler func(ctx context.Context, req *mcp.CreateMessageParams) (*mcp.CreateMessageResult, error)