diff --git a/dotnet/README.md b/dotnet/README.md index 0f67fb11a..cbaa5da14 100644 --- a/dotnet/README.md +++ b/dotnet/README.md @@ -488,6 +488,95 @@ var safeLookup = AIFunctionFactory.Create( }); ``` +### Commands + +Register slash commands so that users of the CLI's TUI can invoke custom actions via `/commandName`. Each command has a `Name`, optional `Description`, and a `Handler` called when the user executes it. + +```csharp +var session = await client.CreateSessionAsync(new SessionConfig +{ + Model = "gpt-5", + OnPermissionRequest = PermissionHandler.ApproveAll, + Commands = + [ + new CommandDefinition + { + Name = "deploy", + Description = "Deploy the app to production", + Handler = async (context) => + { + Console.WriteLine($"Deploying with args: {context.Args}"); + // Do work here — any thrown error is reported back to the CLI + }, + }, + ], +}); +``` + +When the user types `/deploy staging` in the CLI, the SDK receives a `command.execute` event, routes it to your handler, and automatically responds to the CLI. If the handler throws, the error message is forwarded. + +Commands are sent to the CLI on both `CreateSessionAsync` and `ResumeSessionAsync`, so you can update the command set when resuming. + +### UI Elicitation + +When the session has elicitation support — either from the CLI's TUI or from another client that registered an `OnElicitationRequest` handler (see [Elicitation Requests](#elicitation-requests)) — the SDK can request interactive form dialogs from the user. The `session.Ui` object provides convenience methods built on a single generic elicitation RPC. + +> **Capability check:** Elicitation is only available when at least one connected participant advertises support. Always check `session.Capabilities.Ui?.Elicitation` before calling UI methods — this property updates automatically as participants join and leave. + +```csharp +var session = await client.CreateSessionAsync(new SessionConfig +{ + Model = "gpt-5", + OnPermissionRequest = PermissionHandler.ApproveAll, +}); + +if (session.Capabilities.Ui?.Elicitation == true) +{ + // Confirm dialog — returns boolean + bool ok = await session.Ui.ConfirmAsync("Deploy to production?"); + + // Selection dialog — returns selected value or null + string? env = await session.Ui.SelectAsync("Pick environment", + ["production", "staging", "dev"]); + + // Text input — returns string or null + string? name = await session.Ui.InputAsync("Project name:", new InputOptions + { + Title = "Name", + MinLength = 1, + MaxLength = 50, + }); + + // Generic elicitation with full schema control + ElicitationResult result = await session.Ui.ElicitationAsync(new ElicitationParams + { + Message = "Configure deployment", + RequestedSchema = new ElicitationSchema + { + Type = "object", + Properties = new Dictionary + { + ["region"] = new Dictionary + { + ["type"] = "string", + ["enum"] = new[] { "us-east", "eu-west" }, + }, + ["dryRun"] = new Dictionary + { + ["type"] = "boolean", + ["default"] = true, + }, + }, + Required = ["region"], + }, + }); + // result.Action: Accept, Decline, or Cancel + // result.Content: { "region": "us-east", "dryRun": true } (when accepted) +} +``` + +All UI methods throw if elicitation is not supported by the host. + ### System Message Customization Control the system prompt using `SystemMessage` in session config: @@ -812,6 +901,49 @@ var session = await client.CreateSessionAsync(new SessionConfig - `OnSessionEnd` - Cleanup or logging when session ends. - `OnErrorOccurred` - Handle errors with retry/skip/abort strategies. +## Elicitation Requests + +Register an `OnElicitationRequest` handler to let your client act as an elicitation provider — presenting form-based UI dialogs on behalf of the agent. When provided, the server notifies your client whenever a tool or MCP server needs structured user input. + +```csharp +var session = await client.CreateSessionAsync(new SessionConfig +{ + Model = "gpt-5", + OnPermissionRequest = PermissionHandler.ApproveAll, + OnElicitationRequest = async (request, invocation) => + { + // request.Message - Description of what information is needed + // request.RequestedSchema - JSON Schema describing the form fields + // request.Mode - "form" (structured input) or "url" (browser redirect) + // request.ElicitationSource - Origin of the request (e.g. MCP server name) + + Console.WriteLine($"Elicitation from {request.ElicitationSource}: {request.Message}"); + + // Present UI to the user and collect their response... + return new ElicitationResult + { + Action = SessionUiElicitationResultAction.Accept, + Content = new Dictionary + { + ["region"] = "us-east", + ["dryRun"] = true, + }, + }; + }, +}); + +// The session now reports elicitation capability +Console.WriteLine(session.Capabilities.Ui?.Elicitation); // True +``` + +When `OnElicitationRequest` is provided, the SDK sends `RequestElicitation = true` during session create/resume, which enables `session.Capabilities.Ui.Elicitation` on the session. + +In multi-client scenarios: + +- If no connected client was previously providing an elicitation capability, but a new client joins that can, all clients will receive a `capabilities.changed` event to notify them that elicitation is now possible. The SDK automatically updates `session.Capabilities` when these events arrive. +- Similarly, if the last elicitation provider disconnects, all clients receive a `capabilities.changed` event indicating elicitation is no longer available. +- The server fans out elicitation requests to **all** connected clients that registered a handler — the first response wins. + ## Error Handling ```csharp diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index d1cea218e..ada241baa 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -456,6 +456,8 @@ public async Task CreateSessionAsync(SessionConfig config, Cance var session = new CopilotSession(sessionId, connection.Rpc, _logger); session.RegisterTools(config.Tools ?? []); session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterCommands(config.Commands); + session.RegisterElicitationHandler(config.OnElicitationRequest); if (config.OnUserInputRequest != null) { session.RegisterUserInputHandler(config.OnUserInputRequest); @@ -501,13 +503,16 @@ public async Task CreateSessionAsync(SessionConfig config, Cance config.SkillDirectories, config.DisabledSkills, config.InfiniteSessions, - traceparent, - tracestate); + Commands: config.Commands?.Select(c => new CommandWireDefinition(c.Name, c.Description)).ToList(), + RequestElicitation: config.OnElicitationRequest != null, + Traceparent: traceparent, + Tracestate: tracestate); var response = await InvokeRpcAsync( connection.Rpc, "session.create", [request], cancellationToken); session.WorkspacePath = response.WorkspacePath; + session.SetCapabilities(response.Capabilities); } catch { @@ -570,6 +575,8 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes var session = new CopilotSession(sessionId, connection.Rpc, _logger); session.RegisterTools(config.Tools ?? []); session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterCommands(config.Commands); + session.RegisterElicitationHandler(config.OnElicitationRequest); if (config.OnUserInputRequest != null) { session.RegisterUserInputHandler(config.OnUserInputRequest); @@ -616,13 +623,16 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes config.SkillDirectories, config.DisabledSkills, config.InfiniteSessions, - traceparent, - tracestate); + Commands: config.Commands?.Select(c => new CommandWireDefinition(c.Name, c.Description)).ToList(), + RequestElicitation: config.OnElicitationRequest != null, + Traceparent: traceparent, + Tracestate: tracestate); var response = await InvokeRpcAsync( connection.Rpc, "session.resume", [request], cancellationToken); session.WorkspacePath = response.WorkspacePath; + session.SetCapabilities(response.Capabilities); } catch { @@ -1592,6 +1602,8 @@ internal record CreateSessionRequest( List? SkillDirectories, List? DisabledSkills, InfiniteSessionConfig? InfiniteSessions, + List? Commands = null, + bool? RequestElicitation = null, string? Traceparent = null, string? Tracestate = null); @@ -1614,7 +1626,8 @@ public static ToolDefinition FromAIFunction(AIFunction function) internal record CreateSessionResponse( string SessionId, - string? WorkspacePath); + string? WorkspacePath, + SessionCapabilities? Capabilities = null); internal record ResumeSessionRequest( string SessionId, @@ -1640,12 +1653,19 @@ internal record ResumeSessionRequest( List? SkillDirectories, List? DisabledSkills, InfiniteSessionConfig? InfiniteSessions, + List? Commands = null, + bool? RequestElicitation = null, string? Traceparent = null, string? Tracestate = null); internal record ResumeSessionResponse( string SessionId, - string? WorkspacePath); + string? WorkspacePath, + SessionCapabilities? Capabilities = null); + + internal record CommandWireDefinition( + string Name, + string? Description); internal record GetLastSessionIdResponse( string? SessionId); @@ -1782,9 +1802,12 @@ private static LogLevel MapLevel(TraceEventType eventType) [JsonSerializable(typeof(ProviderConfig))] [JsonSerializable(typeof(ResumeSessionRequest))] [JsonSerializable(typeof(ResumeSessionResponse))] + [JsonSerializable(typeof(SessionCapabilities))] + [JsonSerializable(typeof(SessionUiCapabilities))] [JsonSerializable(typeof(SessionMetadata))] [JsonSerializable(typeof(SystemMessageConfig))] [JsonSerializable(typeof(SystemMessageTransformRpcResponse))] + [JsonSerializable(typeof(CommandWireDefinition))] [JsonSerializable(typeof(ToolCallResponseV2))] [JsonSerializable(typeof(ToolDefinition))] [JsonSerializable(typeof(ToolResultAIContent))] diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 675a3e0c0..9697eefa3 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -56,11 +56,13 @@ namespace GitHub.Copilot.SDK; public sealed partial class CopilotSession : IAsyncDisposable { private readonly Dictionary _toolHandlers = []; + private readonly Dictionary _commandHandlers = []; private readonly JsonRpc _rpc; private readonly ILogger _logger; private volatile PermissionRequestHandler? _permissionHandler; private volatile UserInputHandler? _userInputHandler; + private volatile ElicitationHandler? _elicitationHandler; private ImmutableArray _eventHandlers = ImmutableArray.Empty; private SessionHooks? _hooks; @@ -98,6 +100,30 @@ public sealed partial class CopilotSession : IAsyncDisposable /// public string? WorkspacePath { get; internal set; } + /// + /// Gets the capabilities reported by the host for this session. + /// + /// + /// A object describing what the host supports. + /// Capabilities are populated from the session create/resume response and updated + /// in real time via capabilities.changed events. + /// + public SessionCapabilities Capabilities { get; private set; } = new(); + + /// + /// Gets the UI API for eliciting information from the user during this session. + /// + /// + /// An implementation with convenience methods for + /// confirm, select, input, and custom elicitation dialogs. + /// + /// + /// All methods on this property throw + /// if the host does not report elicitation support via . + /// Check session.Capabilities.Ui?.Elicitation == true before calling. + /// + public ISessionUiApi Ui => new SessionUiApiImpl(this); + /// /// Initializes a new instance of the class. /// @@ -436,6 +462,59 @@ private async Task HandleBroadcastEventAsync(SessionEvent sessionEvent) await ExecutePermissionAndRespondAsync(data.RequestId, data.PermissionRequest, handler); break; } + + case CommandExecuteEvent cmdEvent: + { + var data = cmdEvent.Data; + if (string.IsNullOrEmpty(data.RequestId)) + return; + + _ = ExecuteCommandAndRespondAsync(data.RequestId, data.CommandName, data.Command, data.Args); + break; + } + + case ElicitationRequestedEvent elicitEvent: + { + var data = elicitEvent.Data; + if (string.IsNullOrEmpty(data.RequestId)) + return; + + if (_elicitationHandler is not null) + { + var schema = data.RequestedSchema is not null + ? new ElicitationSchema + { + Type = data.RequestedSchema.Type, + Properties = data.RequestedSchema.Properties, + Required = data.RequestedSchema.Required?.ToList() + } + : null; + + _ = HandleElicitationRequestAsync( + new ElicitationRequest + { + Message = data.Message, + RequestedSchema = schema, + Mode = data.Mode?.ToString().ToLowerInvariant(), + ElicitationSource = data.ElicitationSource, + Url = data.Url + }, + data.RequestId); + } + break; + } + + case CapabilitiesChangedEvent capEvent: + { + var data = capEvent.Data; + Capabilities = new SessionCapabilities + { + Ui = data.Ui is not null + ? new SessionUiCapabilities { Elicitation = data.Ui.Elicitation } + : Capabilities.Ui + }; + break; + } } } catch (Exception ex) when (ex is not OperationCanceledException) @@ -557,6 +636,238 @@ internal void RegisterUserInputHandler(UserInputHandler handler) _userInputHandler = handler; } + /// + /// Registers command handlers for this session. + /// + /// The command definitions to register. + internal void RegisterCommands(IEnumerable? commands) + { + _commandHandlers.Clear(); + if (commands is null) return; + foreach (var cmd in commands) + { + _commandHandlers[cmd.Name] = cmd.Handler; + } + } + + /// + /// Registers an elicitation handler for this session. + /// + /// The handler to invoke when an elicitation request is received. + internal void RegisterElicitationHandler(ElicitationHandler? handler) + { + _elicitationHandler = handler; + } + + /// + /// Sets the capabilities reported by the host for this session. + /// + /// The capabilities to set. + internal void SetCapabilities(SessionCapabilities? capabilities) + { + Capabilities = capabilities ?? new SessionCapabilities(); + } + + /// + /// Dispatches a command.execute event to the registered handler and + /// responds via the commands.handlePendingCommand RPC. + /// + private async Task ExecuteCommandAndRespondAsync(string requestId, string commandName, string command, string args) + { + if (!_commandHandlers.TryGetValue(commandName, out var handler)) + { + try + { + await Rpc.Commands.HandlePendingCommandAsync(requestId, error: $"Unknown command: {commandName}"); + } + catch (Exception ex) when (ex is IOException or ObjectDisposedException) + { + // Connection lost — nothing we can do + } + return; + } + + try + { + await handler(new CommandContext + { + SessionId = SessionId, + Command = command, + CommandName = commandName, + Args = args + }); + await Rpc.Commands.HandlePendingCommandAsync(requestId); + } + catch (Exception error) when (error is not OperationCanceledException) + { + // User handler can throw any exception — report the error back to the server + // so the pending command doesn't hang. + var message = error.Message; + try + { + await Rpc.Commands.HandlePendingCommandAsync(requestId, error: message); + } + catch (Exception ex) when (ex is IOException or ObjectDisposedException) + { + // Connection lost — nothing we can do + } + } + } + + /// + /// Dispatches an elicitation.requested event to the registered handler and + /// responds via the ui.handlePendingElicitation RPC. Auto-cancels on handler errors. + /// + private async Task HandleElicitationRequestAsync(ElicitationRequest request, string requestId) + { + var handler = _elicitationHandler; + if (handler is null) return; + + try + { + var result = await handler(request, new ElicitationInvocation { SessionId = SessionId }); + await Rpc.Ui.HandlePendingElicitationAsync(requestId, new SessionUiHandlePendingElicitationRequestResult + { + Action = result.Action, + Content = result.Content + }); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + // User handler can throw any exception — attempt to cancel so the request doesn't hang. + try + { + await Rpc.Ui.HandlePendingElicitationAsync(requestId, new SessionUiHandlePendingElicitationRequestResult + { + Action = SessionUiElicitationResultAction.Cancel + }); + } + catch (Exception innerEx) when (innerEx is IOException or ObjectDisposedException) + { + // Connection lost — nothing we can do + } + } + } + + /// + /// Throws if the host does not support elicitation. + /// + private void AssertElicitation() + { + if (Capabilities.Ui?.Elicitation != true) + { + throw new InvalidOperationException( + "Elicitation is not supported by the host. " + + "Check session.Capabilities.Ui?.Elicitation before calling UI methods."); + } + } + + /// + /// Implements backed by the session's RPC connection. + /// + private sealed class SessionUiApiImpl(CopilotSession session) : ISessionUiApi + { + public async Task ElicitationAsync(ElicitationParams elicitationParams, CancellationToken cancellationToken) + { + session.AssertElicitation(); + var schema = new SessionUiElicitationRequestRequestedSchema + { + Type = elicitationParams.RequestedSchema.Type, + Properties = elicitationParams.RequestedSchema.Properties, + Required = elicitationParams.RequestedSchema.Required + }; + var result = await session.Rpc.Ui.ElicitationAsync(elicitationParams.Message, schema, cancellationToken); + return new ElicitationResult { Action = result.Action, Content = result.Content }; + } + + public async Task ConfirmAsync(string message, CancellationToken cancellationToken) + { + session.AssertElicitation(); + var schema = new SessionUiElicitationRequestRequestedSchema + { + Type = "object", + Properties = new Dictionary + { + ["confirmed"] = new Dictionary { ["type"] = "boolean", ["default"] = true } + }, + Required = ["confirmed"] + }; + var result = await session.Rpc.Ui.ElicitationAsync(message, schema, cancellationToken); + if (result.Action == SessionUiElicitationResultAction.Accept + && result.Content != null + && result.Content.TryGetValue("confirmed", out var val)) + { + return val switch + { + bool b => b, + JsonElement { ValueKind: JsonValueKind.True } => true, + JsonElement { ValueKind: JsonValueKind.False } => false, + _ => false + }; + } + return false; + } + + public async Task SelectAsync(string message, string[] options, CancellationToken cancellationToken) + { + session.AssertElicitation(); + var schema = new SessionUiElicitationRequestRequestedSchema + { + Type = "object", + Properties = new Dictionary + { + ["selection"] = new Dictionary { ["type"] = "string", ["enum"] = options } + }, + Required = ["selection"] + }; + var result = await session.Rpc.Ui.ElicitationAsync(message, schema, cancellationToken); + if (result.Action == SessionUiElicitationResultAction.Accept + && result.Content != null + && result.Content.TryGetValue("selection", out var val)) + { + return val switch + { + string s => s, + JsonElement { ValueKind: JsonValueKind.String } je => je.GetString(), + _ => val.ToString() + }; + } + return null; + } + + public async Task InputAsync(string message, InputOptions? options, CancellationToken cancellationToken) + { + session.AssertElicitation(); + var field = new Dictionary { ["type"] = "string" }; + if (options?.Title != null) field["title"] = options.Title; + if (options?.Description != null) field["description"] = options.Description; + if (options?.MinLength != null) field["minLength"] = options.MinLength; + if (options?.MaxLength != null) field["maxLength"] = options.MaxLength; + if (options?.Format != null) field["format"] = options.Format; + if (options?.Default != null) field["default"] = options.Default; + + var schema = new SessionUiElicitationRequestRequestedSchema + { + Type = "object", + Properties = new Dictionary { ["value"] = field }, + Required = ["value"] + }; + var result = await session.Rpc.Ui.ElicitationAsync(message, schema, cancellationToken); + if (result.Action == SessionUiElicitationResultAction.Accept + && result.Content != null + && result.Content.TryGetValue("value", out var val)) + { + return val switch + { + string s => s, + JsonElement { ValueKind: JsonValueKind.String } je => je.GetString(), + _ => val.ToString() + }; + } + return null; + } + } + /// /// Handles a user input request from the Copilot CLI. /// @@ -890,8 +1201,10 @@ await InvokeRpcAsync( _eventHandlers = ImmutableInterlocked.InterlockedExchange(ref _eventHandlers, ImmutableArray.Empty); _toolHandlers.Clear(); + _commandHandlers.Clear(); _permissionHandler = null; + _elicitationHandler = null; } [LoggerMessage(Level = LogLevel.Error, Message = "Unhandled exception in broadcast event handler")] diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index d6530f9c7..6ed8d7d39 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -7,6 +7,7 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; +using GitHub.Copilot.SDK.Rpc; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; @@ -500,6 +501,260 @@ public class UserInputInvocation /// public delegate Task UserInputHandler(UserInputRequest request, UserInputInvocation invocation); +// ============================================================================ +// Command Handler Types +// ============================================================================ + +/// +/// Defines a slash-command that users can invoke from the CLI TUI. +/// +public class CommandDefinition +{ + /// + /// Command name (without leading /). For example, "deploy". + /// + public required string Name { get; set; } + + /// + /// Human-readable description shown in the command completion UI. + /// + public string? Description { get; set; } + + /// + /// Handler invoked when the command is executed. + /// + public required CommandHandler Handler { get; set; } +} + +/// +/// Context passed to a when a command is executed. +/// +public class CommandContext +{ + /// + /// Session ID where the command was invoked. + /// + public string SessionId { get; set; } = string.Empty; + + /// + /// The full command text (e.g., /deploy production). + /// + public string Command { get; set; } = string.Empty; + + /// + /// Command name without leading /. + /// + public string CommandName { get; set; } = string.Empty; + + /// + /// Raw argument string after the command name. + /// + public string Args { get; set; } = string.Empty; +} + +/// +/// Delegate for handling slash-command executions. +/// +public delegate Task CommandHandler(CommandContext context); + +// ============================================================================ +// Elicitation Types (UI — client → server) +// ============================================================================ + +/// +/// JSON Schema describing the form fields to present for an elicitation dialog. +/// +public class ElicitationSchema +{ + /// + /// Schema type indicator (always "object"). + /// + [JsonPropertyName("type")] + public string Type { get; set; } = "object"; + + /// + /// Form field definitions, keyed by field name. + /// + [JsonPropertyName("properties")] + public Dictionary Properties { get; set; } = []; + + /// + /// List of required field names. + /// + [JsonPropertyName("required")] + public List? Required { get; set; } +} + +/// +/// Parameters for an elicitation request sent from the SDK to the server. +/// +public class ElicitationParams +{ + /// + /// Message describing what information is needed from the user. + /// + public required string Message { get; set; } + + /// + /// JSON Schema describing the form fields to present. + /// + public required ElicitationSchema RequestedSchema { get; set; } +} + +/// +/// Result returned from an elicitation dialog. +/// +public class ElicitationResult +{ + /// + /// User action: "accept" (submitted), "decline" (rejected), or "cancel" (dismissed). + /// + public SessionUiElicitationResultAction Action { get; set; } + + /// + /// Form values submitted by the user (present when is Accept). + /// + public Dictionary? Content { get; set; } +} + +/// +/// Options for the convenience method. +/// +public class InputOptions +{ + /// Title label for the input field. + public string? Title { get; set; } + + /// Descriptive text shown below the field. + public string? Description { get; set; } + + /// Minimum character length. + public int? MinLength { get; set; } + + /// Maximum character length. + public int? MaxLength { get; set; } + + /// Semantic format hint (e.g., "email", "uri", "date", "date-time"). + public string? Format { get; set; } + + /// Default value pre-populated in the field. + public string? Default { get; set; } +} + +/// +/// Provides UI methods for eliciting information from the user during a session. +/// +public interface ISessionUiApi +{ + /// + /// Shows a generic elicitation dialog with a custom schema. + /// + /// The elicitation parameters including message and schema. + /// Optional cancellation token. + /// The with the user's response. + /// Thrown if the host does not support elicitation. + Task ElicitationAsync(ElicitationParams elicitationParams, CancellationToken cancellationToken = default); + + /// + /// Shows a confirmation dialog and returns the user's boolean answer. + /// Returns false if the user declines or cancels. + /// + /// The message to display. + /// Optional cancellation token. + /// true if the user confirmed; otherwise false. + /// Thrown if the host does not support elicitation. + Task ConfirmAsync(string message, CancellationToken cancellationToken = default); + + /// + /// Shows a selection dialog with the given options. + /// Returns the selected value, or null if the user declines/cancels. + /// + /// The message to display. + /// The options to present. + /// Optional cancellation token. + /// The selected string, or null if the user declined/cancelled. + /// Thrown if the host does not support elicitation. + Task SelectAsync(string message, string[] options, CancellationToken cancellationToken = default); + + /// + /// Shows a text input dialog. + /// Returns the entered text, or null if the user declines/cancels. + /// + /// The message to display. + /// Optional input field options. + /// Optional cancellation token. + /// The entered string, or null if the user declined/cancelled. + /// Thrown if the host does not support elicitation. + Task InputAsync(string message, InputOptions? options = null, CancellationToken cancellationToken = default); +} + +// ============================================================================ +// Elicitation Types (server → client callback) +// ============================================================================ + +/// +/// An elicitation request received from the server. +/// +public class ElicitationRequest +{ + /// Message describing what information is needed from the user. + public string Message { get; set; } = string.Empty; + + /// JSON Schema describing the form fields to present. + public ElicitationSchema? RequestedSchema { get; set; } + + /// Elicitation mode: "form" for structured input, "url" for browser redirect. + public string? Mode { get; set; } + + /// The source that initiated the request (e.g., MCP server name). + public string? ElicitationSource { get; set; } + + /// URL to open in the user's browser (url mode only). + public string? Url { get; set; } +} + +/// +/// Context for an elicitation handler invocation. +/// +public class ElicitationInvocation +{ + /// + /// Identifier of the session that triggered the elicitation request. + /// + public string SessionId { get; set; } = string.Empty; +} + +/// +/// Delegate for handling elicitation requests from the server. +/// +public delegate Task ElicitationHandler(ElicitationRequest request, ElicitationInvocation invocation); + +// ============================================================================ +// Session Capabilities +// ============================================================================ + +/// +/// Represents the capabilities reported by the host for a session. +/// +public class SessionCapabilities +{ + /// + /// UI-related capabilities. + /// + public SessionUiCapabilities? Ui { get; set; } +} + +/// +/// UI-specific capability flags for a session. +/// +public class SessionUiCapabilities +{ + /// + /// Whether the host supports interactive elicitation dialogs. + /// + public bool? Elicitation { get; set; } +} + // ============================================================================ // Hook Handler Types // ============================================================================ @@ -1319,6 +1574,7 @@ protected SessionConfig(SessionConfig? other) AvailableTools = other.AvailableTools is not null ? [.. other.AvailableTools] : null; ClientName = other.ClientName; + Commands = other.Commands is not null ? [.. other.Commands] : null; ConfigDir = other.ConfigDir; CustomAgents = other.CustomAgents is not null ? [.. other.CustomAgents] : null; Agent = other.Agent; @@ -1330,6 +1586,7 @@ protected SessionConfig(SessionConfig? other) ? new Dictionary(other.McpServers, other.McpServers.Comparer) : null; Model = other.Model; + OnElicitationRequest = other.OnElicitationRequest; OnEvent = other.OnEvent; OnPermissionRequest = other.OnPermissionRequest; OnUserInputRequest = other.OnUserInputRequest; @@ -1405,6 +1662,20 @@ protected SessionConfig(SessionConfig? other) /// public UserInputHandler? OnUserInputRequest { get; set; } + /// + /// Slash commands registered for this session. + /// When the CLI has a TUI, each command appears as /name for the user to invoke. + /// The handler is called when the user executes the command. + /// + public List? Commands { get; set; } + + /// + /// Handler for elicitation requests from the server or MCP tools. + /// When provided, the server will route elicitation requests to this handler + /// and report elicitation as a supported capability. + /// + public ElicitationHandler? OnElicitationRequest { get; set; } + /// /// Hook handlers for session lifecycle events. /// @@ -1503,6 +1774,7 @@ protected ResumeSessionConfig(ResumeSessionConfig? other) AvailableTools = other.AvailableTools is not null ? [.. other.AvailableTools] : null; ClientName = other.ClientName; + Commands = other.Commands is not null ? [.. other.Commands] : null; ConfigDir = other.ConfigDir; CustomAgents = other.CustomAgents is not null ? [.. other.CustomAgents] : null; Agent = other.Agent; @@ -1515,6 +1787,7 @@ protected ResumeSessionConfig(ResumeSessionConfig? other) ? new Dictionary(other.McpServers, other.McpServers.Comparer) : null; Model = other.Model; + OnElicitationRequest = other.OnElicitationRequest; OnEvent = other.OnEvent; OnPermissionRequest = other.OnPermissionRequest; OnUserInputRequest = other.OnUserInputRequest; @@ -1583,6 +1856,20 @@ protected ResumeSessionConfig(ResumeSessionConfig? other) /// public UserInputHandler? OnUserInputRequest { get; set; } + /// + /// Slash commands registered for this session. + /// When the CLI has a TUI, each command appears as /name for the user to invoke. + /// The handler is called when the user executes the command. + /// + public List? Commands { get; set; } + + /// + /// Handler for elicitation requests from the server or MCP tools. + /// When provided, the server will route elicitation requests to this handler + /// and report elicitation as a supported capability. + /// + public ElicitationHandler? OnElicitationRequest { get; set; } + /// /// Hook handlers for session lifecycle events. /// diff --git a/dotnet/test/CommandsTests.cs b/dotnet/test/CommandsTests.cs new file mode 100644 index 000000000..183e87bc8 --- /dev/null +++ b/dotnet/test/CommandsTests.cs @@ -0,0 +1,138 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.SDK.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.SDK.Test; + +public class CommandsTests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "commands", output) +{ + [Fact] + public async Task Forwards_Commands_In_Session_Create() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Commands = + [ + new CommandDefinition { Name = "deploy", Description = "Deploy the app", Handler = _ => Task.CompletedTask }, + new CommandDefinition { Name = "rollback", Handler = _ => Task.CompletedTask }, + ], + }); + + // Session should be created successfully with commands + Assert.NotNull(session); + Assert.NotNull(session.SessionId); + await session.DisposeAsync(); + } + + [Fact] + public async Task Forwards_Commands_In_Session_Resume() + { + var session1 = await CreateSessionAsync(); + var sessionId = session1.SessionId; + + var session2 = await ResumeSessionAsync(sessionId, new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Commands = + [ + new CommandDefinition { Name = "deploy", Description = "Deploy", Handler = _ => Task.CompletedTask }, + ], + }); + + Assert.NotNull(session2); + Assert.Equal(sessionId, session2.SessionId); + await session2.DisposeAsync(); + } + + [Fact] + public void CommandDefinition_Has_Required_Properties() + { + var cmd = new CommandDefinition + { + Name = "deploy", + Description = "Deploy the app", + Handler = _ => Task.CompletedTask, + }; + + Assert.Equal("deploy", cmd.Name); + Assert.Equal("Deploy the app", cmd.Description); + Assert.NotNull(cmd.Handler); + } + + [Fact] + public void CommandContext_Has_All_Properties() + { + var ctx = new CommandContext + { + SessionId = "session-1", + Command = "/deploy production", + CommandName = "deploy", + Args = "production", + }; + + Assert.Equal("session-1", ctx.SessionId); + Assert.Equal("/deploy production", ctx.Command); + Assert.Equal("deploy", ctx.CommandName); + Assert.Equal("production", ctx.Args); + } + + [Fact] + public async Task Session_With_No_Commands_Creates_Successfully() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + + Assert.NotNull(session); + await session.DisposeAsync(); + } + + [Fact] + public async Task Session_Config_Commands_Are_Cloned() + { + var config = new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Commands = + [ + new CommandDefinition { Name = "deploy", Handler = _ => Task.CompletedTask }, + ], + }; + + var clone = config.Clone(); + + Assert.NotNull(clone.Commands); + Assert.Single(clone.Commands!); + Assert.Equal("deploy", clone.Commands![0].Name); + + // Verify collections are independent + clone.Commands!.Add(new CommandDefinition { Name = "rollback", Handler = _ => Task.CompletedTask }); + Assert.Single(config.Commands!); + } + + [Fact] + public void Resume_Config_Commands_Are_Cloned() + { + var config = new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Commands = + [ + new CommandDefinition { Name = "deploy", Handler = _ => Task.CompletedTask }, + ], + }; + + var clone = config.Clone(); + + Assert.NotNull(clone.Commands); + Assert.Single(clone.Commands!); + Assert.Equal("deploy", clone.Commands![0].Name); + } +} diff --git a/dotnet/test/ElicitationTests.cs b/dotnet/test/ElicitationTests.cs new file mode 100644 index 000000000..484ff3a28 --- /dev/null +++ b/dotnet/test/ElicitationTests.cs @@ -0,0 +1,306 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.SDK.Rpc; +using GitHub.Copilot.SDK.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.SDK.Test; + +public class ElicitationTests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "elicitation", output) +{ + [Fact] + public async Task Defaults_Capabilities_When_Not_Provided() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + + // Default capabilities should exist (even if empty) + Assert.NotNull(session.Capabilities); + await session.DisposeAsync(); + } + + [Fact] + public async Task Elicitation_Throws_When_Capability_Is_Missing() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + + // Capabilities.Ui?.Elicitation should not be true by default (headless mode) + Assert.True(session.Capabilities.Ui?.Elicitation != true); + + // Calling any UI method should throw + var ex = await Assert.ThrowsAsync(async () => + { + await session.Ui.ConfirmAsync("test"); + }); + Assert.Contains("not supported", ex.Message, StringComparison.OrdinalIgnoreCase); + + ex = await Assert.ThrowsAsync(async () => + { + await session.Ui.SelectAsync("test", ["a", "b"]); + }); + Assert.Contains("not supported", ex.Message, StringComparison.OrdinalIgnoreCase); + + ex = await Assert.ThrowsAsync(async () => + { + await session.Ui.InputAsync("test"); + }); + Assert.Contains("not supported", ex.Message, StringComparison.OrdinalIgnoreCase); + + ex = await Assert.ThrowsAsync(async () => + { + await session.Ui.ElicitationAsync(new ElicitationParams + { + Message = "Enter name", + RequestedSchema = new ElicitationSchema + { + Properties = new() { ["name"] = new Dictionary { ["type"] = "string" } }, + Required = ["name"], + }, + }); + }); + Assert.Contains("not supported", ex.Message, StringComparison.OrdinalIgnoreCase); + + await session.DisposeAsync(); + } + + [Fact] + public async Task Sends_RequestElicitation_When_Handler_Provided() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnElicitationRequest = (_, _) => Task.FromResult(new ElicitationResult + { + Action = SessionUiElicitationResultAction.Accept, + Content = new Dictionary(), + }), + }); + + // Session should be created successfully with requestElicitation=true + Assert.NotNull(session); + Assert.NotNull(session.SessionId); + await session.DisposeAsync(); + } + + [Fact] + public async Task Session_With_ElicitationHandler_Reports_Elicitation_Capability() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnElicitationRequest = (_, _) => Task.FromResult(new ElicitationResult + { + Action = SessionUiElicitationResultAction.Accept, + Content = new Dictionary(), + }), + }); + + Assert.True(session.Capabilities.Ui?.Elicitation == true, + "Session with onElicitationRequest should report elicitation capability"); + await session.DisposeAsync(); + } + + [Fact] + public async Task Session_Without_ElicitationHandler_Reports_No_Capability() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + + Assert.True(session.Capabilities.Ui?.Elicitation != true, + "Session without onElicitationRequest should not report elicitation capability"); + await session.DisposeAsync(); + } + + [Fact] + public async Task Session_Without_ElicitationHandler_Creates_Successfully() + { + var session = await CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + + // requestElicitation was false (no handler) + Assert.NotNull(session); + await session.DisposeAsync(); + } + + [Fact] + public void SessionCapabilities_Types_Are_Properly_Structured() + { + var capabilities = new SessionCapabilities + { + Ui = new SessionUiCapabilities { Elicitation = true } + }; + + Assert.NotNull(capabilities.Ui); + Assert.True(capabilities.Ui.Elicitation); + + // Test with null UI + var emptyCapabilities = new SessionCapabilities(); + Assert.Null(emptyCapabilities.Ui); + } + + [Fact] + public void ElicitationSchema_Types_Are_Properly_Structured() + { + var schema = new ElicitationSchema + { + Type = "object", + Properties = new Dictionary + { + ["name"] = new Dictionary { ["type"] = "string", ["minLength"] = 1 }, + ["confirmed"] = new Dictionary { ["type"] = "boolean", ["default"] = true }, + }, + Required = ["name"], + }; + + Assert.Equal("object", schema.Type); + Assert.Equal(2, schema.Properties.Count); + Assert.Single(schema.Required!); + } + + [Fact] + public void ElicitationParams_Types_Are_Properly_Structured() + { + var ep = new ElicitationParams + { + Message = "Enter your name", + RequestedSchema = new ElicitationSchema + { + Properties = new Dictionary + { + ["name"] = new Dictionary { ["type"] = "string" }, + }, + }, + }; + + Assert.Equal("Enter your name", ep.Message); + Assert.NotNull(ep.RequestedSchema); + } + + [Fact] + public void ElicitationResult_Types_Are_Properly_Structured() + { + var result = new ElicitationResult + { + Action = SessionUiElicitationResultAction.Accept, + Content = new Dictionary { ["name"] = "Alice" }, + }; + + Assert.Equal(SessionUiElicitationResultAction.Accept, result.Action); + Assert.NotNull(result.Content); + Assert.Equal("Alice", result.Content!["name"]); + + var declined = new ElicitationResult + { + Action = SessionUiElicitationResultAction.Decline, + }; + Assert.Null(declined.Content); + } + + [Fact] + public void InputOptions_Has_All_Properties() + { + var options = new InputOptions + { + Title = "Email Address", + Description = "Enter your email", + MinLength = 5, + MaxLength = 100, + Format = "email", + Default = "user@example.com", + }; + + Assert.Equal("Email Address", options.Title); + Assert.Equal("Enter your email", options.Description); + Assert.Equal(5, options.MinLength); + Assert.Equal(100, options.MaxLength); + Assert.Equal("email", options.Format); + Assert.Equal("user@example.com", options.Default); + } + + [Fact] + public void ElicitationRequest_Has_All_Properties() + { + var request = new ElicitationRequest + { + Message = "Pick a color", + RequestedSchema = new ElicitationSchema + { + Properties = new Dictionary + { + ["color"] = new Dictionary { ["type"] = "string", ["enum"] = new[] { "red", "blue" } }, + }, + }, + Mode = "form", + ElicitationSource = "mcp-server", + Url = null, + }; + + Assert.Equal("Pick a color", request.Message); + Assert.NotNull(request.RequestedSchema); + Assert.Equal("form", request.Mode); + Assert.Equal("mcp-server", request.ElicitationSource); + Assert.Null(request.Url); + } + + [Fact] + public void ElicitationInvocation_Has_SessionId() + { + var invocation = new ElicitationInvocation + { + SessionId = "session-42" + }; + + Assert.Equal("session-42", invocation.SessionId); + } + + [Fact] + public async Task Session_Config_OnElicitationRequest_Is_Cloned() + { + ElicitationHandler handler = (_, _) => Task.FromResult(new ElicitationResult + { + Action = SessionUiElicitationResultAction.Cancel, + }); + + var config = new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnElicitationRequest = handler, + }; + + var clone = config.Clone(); + + Assert.Same(handler, clone.OnElicitationRequest); + } + + [Fact] + public void Resume_Config_OnElicitationRequest_Is_Cloned() + { + ElicitationHandler handler = (_, _) => Task.FromResult(new ElicitationResult + { + Action = SessionUiElicitationResultAction.Cancel, + }); + + var config = new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnElicitationRequest = handler, + }; + + var clone = config.Clone(); + + Assert.Same(handler, clone.OnElicitationRequest); + } +} diff --git a/dotnet/test/MultiClientCommandsElicitationTests.cs b/dotnet/test/MultiClientCommandsElicitationTests.cs new file mode 100644 index 000000000..6c5645613 --- /dev/null +++ b/dotnet/test/MultiClientCommandsElicitationTests.cs @@ -0,0 +1,261 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Reflection; +using GitHub.Copilot.SDK.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.SDK.Test; + +/// +/// Custom fixture for multi-client commands/elicitation tests. +/// Uses TCP mode so a second (and third) client can connect to the same CLI process. +/// +public class MultiClientCommandsElicitationFixture : IAsyncLifetime +{ + public E2ETestContext Ctx { get; private set; } = null!; + public CopilotClient Client1 { get; private set; } = null!; + + public async Task InitializeAsync() + { + Ctx = await E2ETestContext.CreateAsync(); + Client1 = Ctx.CreateClient(useStdio: false); + } + + public async Task DisposeAsync() + { + if (Client1 is not null) + { + await Client1.ForceStopAsync(); + } + + await Ctx.DisposeAsync(); + } +} + +public class MultiClientCommandsElicitationTests + : IClassFixture, IAsyncLifetime +{ + private readonly MultiClientCommandsElicitationFixture _fixture; + private readonly string _testName; + private CopilotClient? _client2; + private CopilotClient? _client3; + + private E2ETestContext Ctx => _fixture.Ctx; + private CopilotClient Client1 => _fixture.Client1; + + public MultiClientCommandsElicitationTests( + MultiClientCommandsElicitationFixture fixture, + ITestOutputHelper output) + { + _fixture = fixture; + _testName = GetTestName(output); + } + + private static string GetTestName(ITestOutputHelper output) + { + var type = output.GetType(); + var testField = type.GetField("test", BindingFlags.Instance | BindingFlags.NonPublic); + var test = (ITest?)testField?.GetValue(output); + return test?.TestCase.TestMethod.Method.Name + ?? throw new InvalidOperationException("Couldn't find test name"); + } + + public async Task InitializeAsync() + { + await Ctx.ConfigureForTestAsync("multi_client", _testName); + + // Trigger connection so we can read the port + var initSession = await Client1.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + await initSession.DisposeAsync(); + + var port = Client1.ActualPort + ?? throw new InvalidOperationException("Client1 is not using TCP mode; ActualPort is null"); + + _client2 = new CopilotClient(new CopilotClientOptions + { + CliUrl = $"localhost:{port}", + }); + } + + public async Task DisposeAsync() + { + if (_client3 is not null) + { + await _client3.ForceStopAsync(); + _client3 = null; + } + + if (_client2 is not null) + { + await _client2.ForceStopAsync(); + _client2 = null; + } + } + + private CopilotClient Client2 => _client2 + ?? throw new InvalidOperationException("Client2 not initialized"); + + [Fact] + public async Task Client_Receives_Commands_Changed_When_Another_Client_Joins_With_Commands() + { + var session1 = await Client1.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + + // Wait for the commands.changed event deterministically + var commandsChangedTcs = new TaskCompletionSource( + TaskCreationOptions.RunContinuationsAsynchronously); + + using var sub = session1.On(evt => + { + if (evt is CommandsChangedEvent changed) + { + commandsChangedTcs.TrySetResult(changed); + } + }); + + // Client2 joins with commands + var session2 = await Client2.ResumeSessionAsync(session1.SessionId, new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Commands = + [ + new CommandDefinition + { + Name = "deploy", + Description = "Deploy the app", + Handler = _ => Task.CompletedTask, + }, + ], + DisableResume = true, + }); + + var commandsChanged = await commandsChangedTcs.Task.WaitAsync(TimeSpan.FromSeconds(15)); + + Assert.NotNull(commandsChanged.Data.Commands); + Assert.Contains(commandsChanged.Data.Commands, c => + c.Name == "deploy" && c.Description == "Deploy the app"); + + await session2.DisposeAsync(); + } + + [Fact] + public async Task Capabilities_Changed_Fires_When_Second_Client_Joins_With_Elicitation_Handler() + { + // Client1 creates session without elicitation + var session1 = await Client1.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + Assert.True(session1.Capabilities.Ui?.Elicitation != true, + "Session without elicitation handler should not have elicitation capability"); + + // Listen for capabilities.changed event + var capChangedTcs = new TaskCompletionSource( + TaskCreationOptions.RunContinuationsAsynchronously); + + using var sub = session1.On(evt => + { + if (evt is CapabilitiesChangedEvent capEvt) + { + capChangedTcs.TrySetResult(capEvt); + } + }); + + // Client2 joins WITH elicitation handler — triggers capabilities.changed + var session2 = await Client2.ResumeSessionAsync(session1.SessionId, new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnElicitationRequest = (_, _) => Task.FromResult(new ElicitationResult + { + Action = Rpc.SessionUiElicitationResultAction.Accept, + Content = new Dictionary(), + }), + DisableResume = true, + }); + + var capEvent = await capChangedTcs.Task.WaitAsync(TimeSpan.FromSeconds(15)); + + Assert.NotNull(capEvent.Data.Ui); + Assert.True(capEvent.Data.Ui!.Elicitation); + + // Client1's capabilities should have been auto-updated + Assert.True(session1.Capabilities.Ui?.Elicitation == true); + + await session2.DisposeAsync(); + } + + [Fact] + public async Task Capabilities_Changed_Fires_When_Elicitation_Provider_Disconnects() + { + // Client1 creates session without elicitation + var session1 = await Client1.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + }); + Assert.True(session1.Capabilities.Ui?.Elicitation != true, + "Session without elicitation handler should not have elicitation capability"); + + // Wait for elicitation to become available + var capEnabledTcs = new TaskCompletionSource( + TaskCreationOptions.RunContinuationsAsynchronously); + + using var subEnabled = session1.On(evt => + { + if (evt is CapabilitiesChangedEvent { Data.Ui.Elicitation: true }) + { + capEnabledTcs.TrySetResult(true); + } + }); + + // Use a dedicated client (client3) so we can stop it without affecting client2 + var port = Client1.ActualPort + ?? throw new InvalidOperationException("Client1 ActualPort is null"); + _client3 = new CopilotClient(new CopilotClientOptions + { + CliUrl = $"localhost:{port}", + }); + + // Client3 joins WITH elicitation handler + await _client3.ResumeSessionAsync(session1.SessionId, new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnElicitationRequest = (_, _) => Task.FromResult(new ElicitationResult + { + Action = Rpc.SessionUiElicitationResultAction.Accept, + Content = new Dictionary(), + }), + DisableResume = true, + }); + + await capEnabledTcs.Task.WaitAsync(TimeSpan.FromSeconds(15)); + Assert.True(session1.Capabilities.Ui?.Elicitation == true); + + // Now listen for the capability being removed + var capDisabledTcs = new TaskCompletionSource( + TaskCreationOptions.RunContinuationsAsynchronously); + + using var subDisabled = session1.On(evt => + { + if (evt is CapabilitiesChangedEvent { Data.Ui.Elicitation: false }) + { + capDisabledTcs.TrySetResult(true); + } + }); + + // Force-stop client3 — destroys the socket, triggering server-side cleanup + await _client3.ForceStopAsync(); + _client3 = null; + + await capDisabledTcs.Task.WaitAsync(TimeSpan.FromSeconds(15)); + Assert.True(session1.Capabilities.Ui?.Elicitation != true, + "After elicitation provider disconnects, capability should be removed"); + } +} diff --git a/go/README.md b/go/README.md index f29ef9fb7..eebe4c4d3 100644 --- a/go/README.md +++ b/go/README.md @@ -160,6 +160,8 @@ Event types: `SessionLifecycleCreated`, `SessionLifecycleDeleted`, `SessionLifec - `OnPermissionRequest` (PermissionHandlerFunc): **Required.** Handler called before each tool execution to approve or deny it. Use `copilot.PermissionHandler.ApproveAll` to allow everything, or provide a custom function for fine-grained control. See [Permission Handling](#permission-handling) section. - `OnUserInputRequest` (UserInputHandler): Handler for user input requests from the agent (enables ask_user tool). See [User Input Requests](#user-input-requests) section. - `Hooks` (\*SessionHooks): Hook handlers for session lifecycle events. See [Session Hooks](#session-hooks) section. +- `Commands` ([]CommandDefinition): Slash-commands registered for this session. See [Commands](#commands) section. +- `OnElicitationRequest` (ElicitationHandler): Handler for elicitation requests from the server. See [Elicitation Requests](#elicitation-requests-serverclient) section. **ResumeSessionConfig:** @@ -168,6 +170,8 @@ Event types: `SessionLifecycleCreated`, `SessionLifecycleDeleted`, `SessionLifec - `ReasoningEffort` (string): Reasoning effort level for models that support it - `Provider` (\*ProviderConfig): Custom API provider configuration (BYOK). See [Custom Providers](#custom-providers) section. - `Streaming` (bool): Enable streaming delta events +- `Commands` ([]CommandDefinition): Slash-commands. See [Commands](#commands) section. +- `OnElicitationRequest` (ElicitationHandler): Elicitation handler. See [Elicitation Requests](#elicitation-requests-serverclient) section. ### Session @@ -177,10 +181,15 @@ Event types: `SessionLifecycleCreated`, `SessionLifecycleDeleted`, `SessionLifec - `GetMessages(ctx context.Context) ([]SessionEvent, error)` - Get message history - `Disconnect() error` - Disconnect the session (releases in-memory resources, preserves disk state) - `Destroy() error` - *(Deprecated)* Use `Disconnect()` instead +- `UI() *SessionUI` - Interactive UI API for elicitation dialogs +- `Capabilities() SessionCapabilities` - Host capabilities (e.g. elicitation support) ### Helper Functions - `Bool(v bool) *bool` - Helper to create bool pointers for `AutoStart` option +- `Int(v int) *int` - Helper to create int pointers for `MinLength`, `MaxLength` +- `String(v string) *string` - Helper to create string pointers +- `Float64(v float64) *float64` - Helper to create float64 pointers ### System Message Customization @@ -731,6 +740,109 @@ session, err := client.CreateSession(context.Background(), &copilot.SessionConfi - `OnSessionEnd` - Cleanup or logging when session ends. - `OnErrorOccurred` - Handle errors with retry/skip/abort strategies. +## Commands + +Register slash-commands that users can invoke from the CLI TUI. When a user types `/deploy production`, the SDK dispatches to your handler and responds via the RPC layer. + +```go +session, err := client.CreateSession(ctx, &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Commands: []copilot.CommandDefinition{ + { + Name: "deploy", + Description: "Deploy the app to production", + Handler: func(ctx copilot.CommandContext) error { + fmt.Printf("Deploying with args: %s\n", ctx.Args) + // ctx.SessionID, ctx.Command, ctx.CommandName, ctx.Args + return nil + }, + }, + { + Name: "rollback", + Description: "Rollback the last deployment", + Handler: func(ctx copilot.CommandContext) error { + return nil + }, + }, + }, +}) +``` + +Commands are also available when resuming sessions: + +```go +session, err := client.ResumeSession(ctx, sessionID, &copilot.ResumeSessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Commands: []copilot.CommandDefinition{ + {Name: "status", Description: "Show status", Handler: statusHandler}, + }, +}) +``` + +If a handler returns an error, the SDK sends the error message back to the server. Unknown commands automatically receive an error response. + +## UI Elicitation + +The SDK provides convenience methods to ask the user questions via elicitation dialogs. These are gated by host capabilities — check `session.Capabilities().UI.Elicitation` before calling. + +```go +ui := session.UI() + +// Confirmation dialog — returns bool +confirmed, err := ui.Confirm(ctx, "Deploy to production?") + +// Selection dialog — returns (selected string, ok bool, error) +choice, ok, err := ui.Select(ctx, "Pick an environment", []string{"staging", "production"}) + +// Text input — returns (text, ok bool, error) +name, ok, err := ui.Input(ctx, "Enter the release name", &copilot.InputOptions{ + Title: "Release Name", + Description: "A short name for the release", + MinLength: copilot.Int(1), + MaxLength: copilot.Int(50), +}) + +// Full custom elicitation with a schema +result, err := ui.Elicitation(ctx, "Configure deployment", rpc.RequestedSchema{ + Type: rpc.RequestedSchemaTypeObject, + Properties: map[string]rpc.Property{ + "target": {Type: rpc.PropertyTypeString, Enum: []string{"staging", "production"}}, + "force": {Type: rpc.PropertyTypeBoolean}, + }, + Required: []string{"target"}, +}) +// result.Action is "accept", "decline", or "cancel" +// result.Content has the form values when Action is "accept" +``` + +## Elicitation Requests (Server→Client) + +When the server (or an MCP tool) needs to ask the end-user a question, it sends an `elicitation.requested` event. Register a handler to respond: + +```go +session, err := client.CreateSession(ctx, &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + OnElicitationRequest: func(req copilot.ElicitationRequest, inv copilot.ElicitationInvocation) (copilot.ElicitationResult, error) { + // req.Message — what's being asked + // req.RequestedSchema — form schema (if mode is "form") + // req.Mode — "form" or "url" + // req.ElicitationSource — e.g. MCP server name + // req.URL — browser URL (if mode is "url") + + // Return the user's response + return copilot.ElicitationResult{ + Action: "accept", + Content: map[string]any{"confirmed": true}, + }, nil + }, +}) +``` + +When `OnElicitationRequest` is provided, the SDK automatically: +- Sends `requestElicitation: true` in the create/resume payload +- Routes `elicitation.requested` events to your handler +- Auto-cancels the request if your handler returns an error (so the server doesn't hang) + ## Transport Modes ### stdio (Default) diff --git a/go/client.go b/go/client.go index dbb5a3d8f..6f88c768a 100644 --- a/go/client.go +++ b/go/client.go @@ -556,6 +556,17 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses req.DisabledSkills = config.DisabledSkills req.InfiniteSessions = config.InfiniteSessions + if len(config.Commands) > 0 { + cmds := make([]wireCommand, 0, len(config.Commands)) + for _, cmd := range config.Commands { + cmds = append(cmds, wireCommand{Name: cmd.Name, Description: cmd.Description}) + } + req.Commands = cmds + } + if config.OnElicitationRequest != nil { + req.RequestElicitation = Bool(true) + } + if config.Streaming { req.Streaming = Bool(true) } @@ -600,6 +611,12 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses if config.OnEvent != nil { session.On(config.OnEvent) } + if len(config.Commands) > 0 { + session.registerCommands(config.Commands) + } + if config.OnElicitationRequest != nil { + session.registerElicitationHandler(config.OnElicitationRequest) + } c.sessionsMux.Lock() c.sessions[sessionID] = session @@ -622,6 +639,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses } session.workspacePath = response.WorkspacePath + session.setCapabilities(response.Capabilities) return session, nil } @@ -699,6 +717,17 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, req.InfiniteSessions = config.InfiniteSessions req.RequestPermission = Bool(true) + if len(config.Commands) > 0 { + cmds := make([]wireCommand, 0, len(config.Commands)) + for _, cmd := range config.Commands { + cmds = append(cmds, wireCommand{Name: cmd.Name, Description: cmd.Description}) + } + req.Commands = cmds + } + if config.OnElicitationRequest != nil { + req.RequestElicitation = Bool(true) + } + traceparent, tracestate := getTraceContext(ctx) req.Traceparent = traceparent req.Tracestate = tracestate @@ -721,6 +750,12 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, if config.OnEvent != nil { session.On(config.OnEvent) } + if len(config.Commands) > 0 { + session.registerCommands(config.Commands) + } + if config.OnElicitationRequest != nil { + session.registerElicitationHandler(config.OnElicitationRequest) + } c.sessionsMux.Lock() c.sessions[sessionID] = session @@ -743,6 +778,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, } session.workspacePath = response.WorkspacePath + session.setCapabilities(response.Capabilities) return session, nil } diff --git a/go/client_test.go b/go/client_test.go index d7a526cab..8f302f338 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -674,3 +674,175 @@ func TestClient_StartStopRace(t *testing.T) { t.Fatal(err) } } + +func TestCreateSessionRequest_Commands(t *testing.T) { + t.Run("forwards commands in session.create RPC", func(t *testing.T) { + req := createSessionRequest{ + Commands: []wireCommand{ + {Name: "deploy", Description: "Deploy the app"}, + {Name: "rollback", Description: "Rollback last deploy"}, + }, + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + cmds, ok := m["commands"].([]any) + if !ok { + t.Fatalf("Expected commands to be an array, got %T", m["commands"]) + } + if len(cmds) != 2 { + t.Fatalf("Expected 2 commands, got %d", len(cmds)) + } + cmd0 := cmds[0].(map[string]any) + if cmd0["name"] != "deploy" { + t.Errorf("Expected first command name 'deploy', got %v", cmd0["name"]) + } + if cmd0["description"] != "Deploy the app" { + t.Errorf("Expected first command description 'Deploy the app', got %v", cmd0["description"]) + } + }) + + t.Run("omits commands from JSON when empty", func(t *testing.T) { + req := createSessionRequest{} + data, _ := json.Marshal(req) + var m map[string]any + json.Unmarshal(data, &m) + if _, ok := m["commands"]; ok { + t.Error("Expected commands to be omitted when empty") + } + }) +} + +func TestResumeSessionRequest_Commands(t *testing.T) { + t.Run("forwards commands in session.resume RPC", func(t *testing.T) { + req := resumeSessionRequest{ + SessionID: "s1", + Commands: []wireCommand{ + {Name: "deploy", Description: "Deploy the app"}, + }, + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + cmds, ok := m["commands"].([]any) + if !ok { + t.Fatalf("Expected commands to be an array, got %T", m["commands"]) + } + if len(cmds) != 1 { + t.Fatalf("Expected 1 command, got %d", len(cmds)) + } + cmd0 := cmds[0].(map[string]any) + if cmd0["name"] != "deploy" { + t.Errorf("Expected command name 'deploy', got %v", cmd0["name"]) + } + }) + + t.Run("omits commands from JSON when empty", func(t *testing.T) { + req := resumeSessionRequest{SessionID: "s1"} + data, _ := json.Marshal(req) + var m map[string]any + json.Unmarshal(data, &m) + if _, ok := m["commands"]; ok { + t.Error("Expected commands to be omitted when empty") + } + }) +} + +func TestCreateSessionRequest_RequestElicitation(t *testing.T) { + t.Run("sends requestElicitation flag when OnElicitationRequest is provided", func(t *testing.T) { + req := createSessionRequest{ + RequestElicitation: Bool(true), + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if m["requestElicitation"] != true { + t.Errorf("Expected requestElicitation to be true, got %v", m["requestElicitation"]) + } + }) + + t.Run("does not send requestElicitation when no handler provided", func(t *testing.T) { + req := createSessionRequest{} + data, _ := json.Marshal(req) + var m map[string]any + json.Unmarshal(data, &m) + if _, ok := m["requestElicitation"]; ok { + t.Error("Expected requestElicitation to be omitted when not set") + } + }) +} + +func TestResumeSessionRequest_RequestElicitation(t *testing.T) { + t.Run("sends requestElicitation flag when OnElicitationRequest is provided", func(t *testing.T) { + req := resumeSessionRequest{ + SessionID: "s1", + RequestElicitation: Bool(true), + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if m["requestElicitation"] != true { + t.Errorf("Expected requestElicitation to be true, got %v", m["requestElicitation"]) + } + }) + + t.Run("does not send requestElicitation when no handler provided", func(t *testing.T) { + req := resumeSessionRequest{SessionID: "s1"} + data, _ := json.Marshal(req) + var m map[string]any + json.Unmarshal(data, &m) + if _, ok := m["requestElicitation"]; ok { + t.Error("Expected requestElicitation to be omitted when not set") + } + }) +} + +func TestCreateSessionResponse_Capabilities(t *testing.T) { + t.Run("reads capabilities from session.create response", func(t *testing.T) { + responseJSON := `{"sessionId":"s1","workspacePath":"/tmp","capabilities":{"ui":{"elicitation":true}}}` + var response createSessionResponse + if err := json.Unmarshal([]byte(responseJSON), &response); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if response.Capabilities == nil { + t.Fatal("Expected capabilities to be non-nil") + } + if response.Capabilities.UI == nil { + t.Fatal("Expected capabilities.UI to be non-nil") + } + if !response.Capabilities.UI.Elicitation { + t.Errorf("Expected capabilities.UI.Elicitation to be true") + } + }) + + t.Run("defaults capabilities when not present", func(t *testing.T) { + responseJSON := `{"sessionId":"s1","workspacePath":"/tmp"}` + var response createSessionResponse + if err := json.Unmarshal([]byte(responseJSON), &response); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if response.Capabilities != nil && response.Capabilities.UI != nil && response.Capabilities.UI.Elicitation { + t.Errorf("Expected capabilities.UI.Elicitation to be falsy when not injected") + } + }) +} diff --git a/go/internal/e2e/commands_and_elicitation_test.go b/go/internal/e2e/commands_and_elicitation_test.go new file mode 100644 index 000000000..35814ef1d --- /dev/null +++ b/go/internal/e2e/commands_and_elicitation_test.go @@ -0,0 +1,357 @@ +package e2e + +import ( + "fmt" + "strings" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +func TestCommands(t *testing.T) { + ctx := testharness.NewTestContext(t) + client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + }) + t.Cleanup(func() { client1.ForceStop() }) + + // Start client1 with an init session to get the port + initSession, err := client1.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create init session: %v", err) + } + initSession.Disconnect() + + actualPort := client1.ActualPort() + if actualPort == 0 { + t.Fatalf("Expected non-zero port from TCP mode client") + } + + client2 := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + }) + t.Cleanup(func() { client2.ForceStop() }) + + t.Run("commands.changed event when another client joins with commands", func(t *testing.T) { + ctx.ConfigureForTest(t) + + // Client1 creates a session without commands + session1, err := client1.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Listen for commands.changed event on client1 + commandsChangedCh := make(chan copilot.SessionEvent, 1) + unsubscribe := session1.On(func(event copilot.SessionEvent) { + if event.Type == copilot.SessionEventTypeCommandsChanged { + select { + case commandsChangedCh <- event: + default: + } + } + }) + defer unsubscribe() + + // Client2 joins with commands + session2, err := client2.ResumeSession(t.Context(), session1.SessionID, &copilot.ResumeSessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, + Commands: []copilot.CommandDefinition{ + { + Name: "deploy", + Description: "Deploy the app", + Handler: func(ctx copilot.CommandContext) error { return nil }, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to resume session: %v", err) + } + + select { + case event := <-commandsChangedCh: + if len(event.Data.Commands) == 0 { + t.Errorf("Expected commands in commands.changed event") + } else { + found := false + for _, cmd := range event.Data.Commands { + if cmd.Name == "deploy" { + found = true + if cmd.Description == nil || *cmd.Description != "Deploy the app" { + t.Errorf("Expected deploy command description 'Deploy the app', got %v", cmd.Description) + } + break + } + } + if !found { + t.Errorf("Expected 'deploy' command in commands.changed event, got %+v", event.Data.Commands) + } + } + case <-time.After(30 * time.Second): + t.Fatal("Timed out waiting for commands.changed event") + } + + session2.Disconnect() + }) +} + +func TestUIElicitation(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient() + t.Cleanup(func() { client.ForceStop() }) + + t.Run("elicitation methods error in headless mode", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Verify capabilities report no elicitation + caps := session.Capabilities() + if caps.UI != nil && caps.UI.Elicitation { + t.Error("Expected no elicitation capability in headless mode") + } + + // All UI methods should return a "not supported" error + ui := session.UI() + + _, err = ui.Confirm(t.Context(), "Are you sure?") + if err == nil { + t.Error("Expected error calling Confirm without elicitation capability") + } else if !strings.Contains(err.Error(), "not supported") { + t.Errorf("Expected 'not supported' in error message, got: %s", err.Error()) + } + + _, _, err = ui.Select(t.Context(), "Pick one", []string{"a", "b"}) + if err == nil { + t.Error("Expected error calling Select without elicitation capability") + } else if !strings.Contains(err.Error(), "not supported") { + t.Errorf("Expected 'not supported' in error message, got: %s", err.Error()) + } + + _, _, err = ui.Input(t.Context(), "Enter name", nil) + if err == nil { + t.Error("Expected error calling Input without elicitation capability") + } else if !strings.Contains(err.Error(), "not supported") { + t.Errorf("Expected 'not supported' in error message, got: %s", err.Error()) + } + }) +} + +func TestUIElicitationCallback(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient() + t.Cleanup(func() { client.ForceStop() }) + + t.Run("session with OnElicitationRequest reports elicitation capability", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + OnElicitationRequest: func(req copilot.ElicitationRequest, inv copilot.ElicitationInvocation) (copilot.ElicitationResult, error) { + return copilot.ElicitationResult{Action: "accept", Content: map[string]any{}}, nil + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + caps := session.Capabilities() + if caps.UI == nil || !caps.UI.Elicitation { + // The test harness may or may not include capabilities in the response. + // When running against a real CLI, this will be true. + t.Logf("Note: capabilities.ui.elicitation=%v (may be false with test harness)", caps.UI) + } + }) + + t.Run("session without OnElicitationRequest reports no elicitation capability", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + caps := session.Capabilities() + if caps.UI != nil && caps.UI.Elicitation { + t.Error("Expected no elicitation capability when OnElicitationRequest is not provided") + } + }) +} + +func TestUIElicitationMultiClient(t *testing.T) { + ctx := testharness.NewTestContext(t) + client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + }) + t.Cleanup(func() { client1.ForceStop() }) + + // Start client1 with an init session to get the port + initSession, err := client1.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create init session: %v", err) + } + initSession.Disconnect() + + actualPort := client1.ActualPort() + if actualPort == 0 { + t.Fatalf("Expected non-zero port from TCP mode client") + } + + t.Run("capabilities.changed fires when second client joins with elicitation handler", func(t *testing.T) { + ctx.ConfigureForTest(t) + + // Client1 creates a session without elicitation handler + session1, err := client1.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Verify initial state: no elicitation capability + caps := session1.Capabilities() + if caps.UI != nil && caps.UI.Elicitation { + t.Error("Expected no elicitation capability before second client joins") + } + + // Listen for capabilities.changed with elicitation enabled + capEnabledCh := make(chan copilot.SessionEvent, 1) + unsubscribe := session1.On(func(event copilot.SessionEvent) { + if event.Type == copilot.SessionEventTypeCapabilitiesChanged { + if event.Data.UI != nil && event.Data.UI.Elicitation != nil && *event.Data.UI.Elicitation { + select { + case capEnabledCh <- event: + default: + } + } + } + }) + + // Client2 joins with elicitation handler — should trigger capabilities.changed + client2 := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + }) + session2, err := client2.ResumeSession(t.Context(), session1.SessionID, &copilot.ResumeSessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, + OnElicitationRequest: func(req copilot.ElicitationRequest, inv copilot.ElicitationInvocation) (copilot.ElicitationResult, error) { + return copilot.ElicitationResult{Action: "accept", Content: map[string]any{}}, nil + }, + }) + if err != nil { + client2.ForceStop() + t.Fatalf("Failed to resume session: %v", err) + } + + // Wait for the elicitation-enabled capabilities.changed event + select { + case capEvent := <-capEnabledCh: + if capEvent.Data.UI == nil || capEvent.Data.UI.Elicitation == nil || !*capEvent.Data.UI.Elicitation { + t.Errorf("Expected capabilities.changed with ui.elicitation=true, got %+v", capEvent.Data.UI) + } + case <-time.After(30 * time.Second): + t.Fatal("Timed out waiting for capabilities.changed event (elicitation enabled)") + } + + unsubscribe() + session2.Disconnect() + client2.ForceStop() + }) + + t.Run("capabilities.changed fires when elicitation provider disconnects", func(t *testing.T) { + ctx.ConfigureForTest(t) + + // Client1 creates a session without elicitation handler + session1, err := client1.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Verify initial state: no elicitation capability + caps := session1.Capabilities() + if caps.UI != nil && caps.UI.Elicitation { + t.Error("Expected no elicitation capability before provider joins") + } + + // Listen for capability enabled + capEnabledCh := make(chan struct{}, 1) + unsubEnabled := session1.On(func(event copilot.SessionEvent) { + if event.Type == copilot.SessionEventTypeCapabilitiesChanged { + if event.Data.UI != nil && event.Data.UI.Elicitation != nil && *event.Data.UI.Elicitation { + select { + case capEnabledCh <- struct{}{}: + default: + } + } + } + }) + + // Client3 (dedicated for this test) joins with elicitation handler + client3 := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + }) + _, err = client3.ResumeSession(t.Context(), session1.SessionID, &copilot.ResumeSessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + DisableResume: true, + OnElicitationRequest: func(req copilot.ElicitationRequest, inv copilot.ElicitationInvocation) (copilot.ElicitationResult, error) { + return copilot.ElicitationResult{Action: "accept", Content: map[string]any{}}, nil + }, + }) + if err != nil { + client3.ForceStop() + t.Fatalf("Failed to resume session for client3: %v", err) + } + + // Wait for elicitation to become enabled + select { + case <-capEnabledCh: + // Good — elicitation is now enabled + case <-time.After(30 * time.Second): + client3.ForceStop() + t.Fatal("Timed out waiting for capabilities.changed event (elicitation enabled)") + } + unsubEnabled() + + // Now listen for elicitation to become disabled + capDisabledCh := make(chan struct{}, 1) + unsubDisabled := session1.On(func(event copilot.SessionEvent) { + if event.Type == copilot.SessionEventTypeCapabilitiesChanged { + if event.Data.UI != nil && event.Data.UI.Elicitation != nil && !*event.Data.UI.Elicitation { + select { + case capDisabledCh <- struct{}{}: + default: + } + } + } + }) + + // Disconnect client3 — should trigger capabilities.changed with elicitation=false + client3.ForceStop() + + select { + case <-capDisabledCh: + // Good — got the disabled event + case <-time.After(30 * time.Second): + t.Fatal("Timed out waiting for capabilities.changed event (elicitation disabled)") + } + unsubDisabled() + }) +} diff --git a/go/session.go b/go/session.go index 5be626b52..d8aec8b96 100644 --- a/go/session.go +++ b/go/session.go @@ -66,6 +66,12 @@ type Session struct { hooksMux sync.RWMutex transformCallbacks map[string]SectionTransformFn transformMu sync.Mutex + commandHandlers map[string]CommandHandler + commandHandlersMu sync.RWMutex + elicitationHandler ElicitationHandler + elicitationMu sync.RWMutex + capabilities SessionCapabilities + capabilitiesMu sync.RWMutex // eventCh serializes user event handler dispatch. dispatchEvent enqueues; // a single goroutine (processEvents) dequeues and invokes handlers in FIFO order. @@ -86,13 +92,14 @@ func (s *Session) WorkspacePath() string { // newSession creates a new session wrapper with the given session ID and client. func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { s := &Session{ - SessionID: sessionID, - workspacePath: workspacePath, - client: client, - handlers: make([]sessionHandler, 0), - toolHandlers: make(map[string]ToolHandler), - eventCh: make(chan SessionEvent, 128), - RPC: rpc.NewSessionRpc(client, sessionID), + SessionID: sessionID, + workspacePath: workspacePath, + client: client, + handlers: make([]sessionHandler, 0), + toolHandlers: make(map[string]ToolHandler), + commandHandlers: make(map[string]CommandHandler), + eventCh: make(chan SessionEvent, 128), + RPC: rpc.NewSessionRpc(client, sessionID), } go s.processEvents() return s @@ -498,6 +505,334 @@ func (s *Session) handleSystemMessageTransform(sections map[string]systemMessage return systemMessageTransformResponse{Sections: result}, nil } +// registerCommands registers command handlers for this session. +func (s *Session) registerCommands(commands []CommandDefinition) { + s.commandHandlersMu.Lock() + defer s.commandHandlersMu.Unlock() + s.commandHandlers = make(map[string]CommandHandler) + for _, cmd := range commands { + if cmd.Name == "" || cmd.Handler == nil { + continue + } + s.commandHandlers[cmd.Name] = cmd.Handler + } +} + +// getCommandHandler retrieves a registered command handler by name. +func (s *Session) getCommandHandler(name string) (CommandHandler, bool) { + s.commandHandlersMu.RLock() + handler, ok := s.commandHandlers[name] + s.commandHandlersMu.RUnlock() + return handler, ok +} + +// executeCommandAndRespond dispatches a command.execute event to the registered handler +// and sends the result (or error) back via the RPC layer. +func (s *Session) executeCommandAndRespond(requestID, commandName, command, args string) { + ctx := context.Background() + handler, ok := s.getCommandHandler(commandName) + if !ok { + errMsg := fmt.Sprintf("Unknown command: %s", commandName) + s.RPC.Commands.HandlePendingCommand(ctx, &rpc.SessionCommandsHandlePendingCommandParams{ + RequestID: requestID, + Error: &errMsg, + }) + return + } + + cmdCtx := CommandContext{ + SessionID: s.SessionID, + Command: command, + CommandName: commandName, + Args: args, + } + + if err := handler(cmdCtx); err != nil { + errMsg := err.Error() + s.RPC.Commands.HandlePendingCommand(ctx, &rpc.SessionCommandsHandlePendingCommandParams{ + RequestID: requestID, + Error: &errMsg, + }) + return + } + + s.RPC.Commands.HandlePendingCommand(ctx, &rpc.SessionCommandsHandlePendingCommandParams{ + RequestID: requestID, + }) +} + +// registerElicitationHandler registers an elicitation handler for this session. +func (s *Session) registerElicitationHandler(handler ElicitationHandler) { + s.elicitationMu.Lock() + defer s.elicitationMu.Unlock() + s.elicitationHandler = handler +} + +// getElicitationHandler returns the currently registered elicitation handler, or nil. +func (s *Session) getElicitationHandler() ElicitationHandler { + s.elicitationMu.RLock() + defer s.elicitationMu.RUnlock() + return s.elicitationHandler +} + +// handleElicitationRequest dispatches an elicitation.requested event to the registered handler +// and sends the result back via the RPC layer. Auto-cancels on error. +func (s *Session) handleElicitationRequest(request ElicitationRequest, requestID string) { + handler := s.getElicitationHandler() + if handler == nil { + return + } + + ctx := context.Background() + invocation := ElicitationInvocation{SessionID: s.SessionID} + + result, err := handler(request, invocation) + if err != nil { + // Handler failed — attempt to cancel so the request doesn't hang. + s.RPC.Ui.HandlePendingElicitation(ctx, &rpc.SessionUIHandlePendingElicitationParams{ + RequestID: requestID, + Result: rpc.SessionUIHandlePendingElicitationParamsResult{ + Action: rpc.ActionCancel, + }, + }) + return + } + + rpcContent := make(map[string]*rpc.Content) + for k, v := range result.Content { + rpcContent[k] = toRPCContent(v) + } + + s.RPC.Ui.HandlePendingElicitation(ctx, &rpc.SessionUIHandlePendingElicitationParams{ + RequestID: requestID, + Result: rpc.SessionUIHandlePendingElicitationParamsResult{ + Action: rpc.Action(result.Action), + Content: rpcContent, + }, + }) +} + +// toRPCContent converts an arbitrary value to a *rpc.Content for elicitation responses. +func toRPCContent(v any) *rpc.Content { + if v == nil { + return nil + } + c := &rpc.Content{} + switch val := v.(type) { + case bool: + c.Bool = &val + case float64: + c.Double = &val + case int: + f := float64(val) + c.Double = &f + case string: + c.String = &val + case []string: + c.StringArray = val + case []any: + strs := make([]string, 0, len(val)) + for _, item := range val { + if s, ok := item.(string); ok { + strs = append(strs, s) + } + } + c.StringArray = strs + default: + s := fmt.Sprintf("%v", val) + c.String = &s + } + return c +} + +// Capabilities returns the session capabilities reported by the server. +func (s *Session) Capabilities() SessionCapabilities { + s.capabilitiesMu.RLock() + defer s.capabilitiesMu.RUnlock() + return s.capabilities +} + +// setCapabilities updates the session capabilities. +func (s *Session) setCapabilities(caps *SessionCapabilities) { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() + if caps != nil { + s.capabilities = *caps + } else { + s.capabilities = SessionCapabilities{} + } +} + +// UI returns the interactive UI API for showing elicitation dialogs. +// Methods on the returned SessionUI will error if the host does not support +// elicitation (check Capabilities().UI.Elicitation first). +func (s *Session) UI() *SessionUI { + return &SessionUI{session: s} +} + +// assertElicitation checks that the host supports elicitation and returns an error if not. +func (s *Session) assertElicitation() error { + caps := s.Capabilities() + if caps.UI == nil || !caps.UI.Elicitation { + return fmt.Errorf("elicitation is not supported by the host; check session.Capabilities().UI.Elicitation before calling UI methods") + } + return nil +} + +// Elicitation shows a generic elicitation dialog with a custom schema. +func (ui *SessionUI) Elicitation(ctx context.Context, message string, requestedSchema rpc.RequestedSchema) (*ElicitationResult, error) { + if err := ui.session.assertElicitation(); err != nil { + return nil, err + } + rpcResult, err := ui.session.RPC.Ui.Elicitation(ctx, &rpc.SessionUIElicitationParams{ + Message: message, + RequestedSchema: requestedSchema, + }) + if err != nil { + return nil, err + } + return fromRPCElicitationResult(rpcResult), nil +} + +// Confirm shows a confirmation dialog and returns the user's boolean answer. +// Returns false if the user declines or cancels. +func (ui *SessionUI) Confirm(ctx context.Context, message string) (bool, error) { + if err := ui.session.assertElicitation(); err != nil { + return false, err + } + defaultTrue := &rpc.Content{Bool: Bool(true)} + rpcResult, err := ui.session.RPC.Ui.Elicitation(ctx, &rpc.SessionUIElicitationParams{ + Message: message, + RequestedSchema: rpc.RequestedSchema{ + Type: rpc.RequestedSchemaTypeObject, + Properties: map[string]rpc.Property{ + "confirmed": { + Type: rpc.PropertyTypeBoolean, + Default: defaultTrue, + }, + }, + Required: []string{"confirmed"}, + }, + }) + if err != nil { + return false, err + } + if rpcResult.Action == rpc.ActionAccept { + if c, ok := rpcResult.Content["confirmed"]; ok && c != nil && c.Bool != nil { + return *c.Bool, nil + } + } + return false, nil +} + +// Select shows a selection dialog with the given options. +// Returns the selected string, or empty string and false if the user declines/cancels. +func (ui *SessionUI) Select(ctx context.Context, message string, options []string) (string, bool, error) { + if err := ui.session.assertElicitation(); err != nil { + return "", false, err + } + rpcResult, err := ui.session.RPC.Ui.Elicitation(ctx, &rpc.SessionUIElicitationParams{ + Message: message, + RequestedSchema: rpc.RequestedSchema{ + Type: rpc.RequestedSchemaTypeObject, + Properties: map[string]rpc.Property{ + "selection": { + Type: rpc.PropertyTypeString, + Enum: options, + }, + }, + Required: []string{"selection"}, + }, + }) + if err != nil { + return "", false, err + } + if rpcResult.Action == rpc.ActionAccept { + if c, ok := rpcResult.Content["selection"]; ok && c != nil && c.String != nil { + return *c.String, true, nil + } + } + return "", false, nil +} + +// Input shows a text input dialog. Returns the entered text, or empty string and +// false if the user declines/cancels. +func (ui *SessionUI) Input(ctx context.Context, message string, opts *InputOptions) (string, bool, error) { + if err := ui.session.assertElicitation(); err != nil { + return "", false, err + } + prop := rpc.Property{Type: rpc.PropertyTypeString} + if opts != nil { + if opts.Title != "" { + prop.Title = &opts.Title + } + if opts.Description != "" { + prop.Description = &opts.Description + } + if opts.MinLength != nil { + f := float64(*opts.MinLength) + prop.MinLength = &f + } + if opts.MaxLength != nil { + f := float64(*opts.MaxLength) + prop.MaxLength = &f + } + if opts.Format != "" { + format := rpc.Format(opts.Format) + prop.Format = &format + } + if opts.Default != "" { + prop.Default = &rpc.Content{String: &opts.Default} + } + } + rpcResult, err := ui.session.RPC.Ui.Elicitation(ctx, &rpc.SessionUIElicitationParams{ + Message: message, + RequestedSchema: rpc.RequestedSchema{ + Type: rpc.RequestedSchemaTypeObject, + Properties: map[string]rpc.Property{ + "value": prop, + }, + Required: []string{"value"}, + }, + }) + if err != nil { + return "", false, err + } + if rpcResult.Action == rpc.ActionAccept { + if c, ok := rpcResult.Content["value"]; ok && c != nil && c.String != nil { + return *c.String, true, nil + } + } + return "", false, nil +} + +// fromRPCElicitationResult converts the RPC result to the SDK ElicitationResult. +func fromRPCElicitationResult(r *rpc.SessionUIElicitationResult) *ElicitationResult { + if r == nil { + return nil + } + content := make(map[string]any) + for k, v := range r.Content { + if v == nil { + content[k] = nil + continue + } + if v.Bool != nil { + content[k] = *v.Bool + } else if v.Double != nil { + content[k] = *v.Double + } else if v.String != nil { + content[k] = *v.String + } else if v.StringArray != nil { + content[k] = v.StringArray + } + } + return &ElicitationResult{ + Action: string(r.Action), + Content: content, + } +} + // dispatchEvent enqueues an event for delivery to user handlers and fires // broadcast handlers concurrently. // @@ -586,6 +921,62 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { return } s.executePermissionAndRespond(*requestID, *event.Data.PermissionRequest, handler) + + case SessionEventTypeCommandExecute: + requestID := event.Data.RequestID + if requestID == nil { + return + } + commandName := "" + if event.Data.CommandName != nil { + commandName = *event.Data.CommandName + } + command := "" + if event.Data.Command != nil { + command = *event.Data.Command + } + args := "" + if event.Data.Args != nil { + args = *event.Data.Args + } + s.executeCommandAndRespond(*requestID, commandName, command, args) + + case SessionEventTypeElicitationRequested: + requestID := event.Data.RequestID + if requestID == nil { + return + } + handler := s.getElicitationHandler() + if handler == nil { + return + } + message := "" + if event.Data.Message != nil { + message = *event.Data.Message + } + var requestedSchema map[string]any + if event.Data.RequestedSchema != nil { + requestedSchema = event.Data.RequestedSchema.Properties + } + mode := "" + if event.Data.Mode != nil { + mode = string(*event.Data.Mode) + } + elicitationSource := "" + if event.Data.ElicitationSource != nil { + elicitationSource = *event.Data.ElicitationSource + } + url := "" + if event.Data.URL != nil { + url = *event.Data.URL + } + s.handleElicitationRequest(ElicitationRequest{ + Message: message, + RequestedSchema: requestedSchema, + Mode: mode, + ElicitationSource: elicitationSource, + URL: url, + }, *requestID) } } @@ -748,6 +1139,14 @@ func (s *Session) Disconnect() error { s.permissionHandler = nil s.permissionMux.Unlock() + s.commandHandlersMu.Lock() + s.commandHandlers = nil + s.commandHandlersMu.Unlock() + + s.elicitationMu.Lock() + s.elicitationHandler = nil + s.elicitationMu.Unlock() + return nil } diff --git a/go/session_test.go b/go/session_test.go index 664c06e55..dc9d3c209 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -11,8 +11,9 @@ import ( // Returns a cleanup function that closes the channel (stopping the consumer). func newTestSession() (*Session, func()) { s := &Session{ - handlers: make([]sessionHandler, 0), - eventCh: make(chan SessionEvent, 128), + handlers: make([]sessionHandler, 0), + commandHandlers: make(map[string]CommandHandler), + eventCh: make(chan SessionEvent, 128), } go s.processEvents() return s, func() { close(s.eventCh) } @@ -204,3 +205,192 @@ func TestSession_On(t *testing.T) { } }) } + +func TestSession_CommandRouting(t *testing.T) { + t.Run("routes command.execute event to the correct handler", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + var receivedCtx CommandContext + session.registerCommands([]CommandDefinition{ + { + Name: "deploy", + Description: "Deploy the app", + Handler: func(ctx CommandContext) error { + receivedCtx = ctx + return nil + }, + }, + { + Name: "rollback", + Description: "Rollback", + Handler: func(ctx CommandContext) error { + return nil + }, + }, + }) + + // Simulate the dispatch — executeCommandAndRespond will fail on RPC (nil client) + // but the handler will still be invoked. We test routing only. + _, ok := session.getCommandHandler("deploy") + if !ok { + t.Fatal("Expected 'deploy' handler to be registered") + } + _, ok = session.getCommandHandler("rollback") + if !ok { + t.Fatal("Expected 'rollback' handler to be registered") + } + _, ok = session.getCommandHandler("nonexistent") + if ok { + t.Fatal("Expected 'nonexistent' handler to NOT be registered") + } + + // Directly invoke handler to verify context is correct + handler, _ := session.getCommandHandler("deploy") + err := handler(CommandContext{ + SessionID: "test-session", + Command: "/deploy production", + CommandName: "deploy", + Args: "production", + }) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + if receivedCtx.SessionID != "test-session" { + t.Errorf("Expected sessionID 'test-session', got %q", receivedCtx.SessionID) + } + if receivedCtx.CommandName != "deploy" { + t.Errorf("Expected commandName 'deploy', got %q", receivedCtx.CommandName) + } + if receivedCtx.Command != "/deploy production" { + t.Errorf("Expected command '/deploy production', got %q", receivedCtx.Command) + } + if receivedCtx.Args != "production" { + t.Errorf("Expected args 'production', got %q", receivedCtx.Args) + } + }) + + t.Run("skips commands with empty name or nil handler", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + session.registerCommands([]CommandDefinition{ + {Name: "", Handler: func(ctx CommandContext) error { return nil }}, + {Name: "valid", Handler: nil}, + {Name: "good", Handler: func(ctx CommandContext) error { return nil }}, + }) + + _, ok := session.getCommandHandler("") + if ok { + t.Error("Empty name should not be registered") + } + _, ok = session.getCommandHandler("valid") + if ok { + t.Error("Nil handler should not be registered") + } + _, ok = session.getCommandHandler("good") + if !ok { + t.Error("Expected 'good' handler to be registered") + } + }) +} + +func TestSession_Capabilities(t *testing.T) { + t.Run("defaults capabilities when not injected", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + caps := session.Capabilities() + if caps.UI != nil { + t.Errorf("Expected UI to be nil by default, got %+v", caps.UI) + } + }) + + t.Run("setCapabilities stores and retrieves capabilities", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + session.setCapabilities(&SessionCapabilities{ + UI: &UICapabilities{Elicitation: true}, + }) + caps := session.Capabilities() + if caps.UI == nil || !caps.UI.Elicitation { + t.Errorf("Expected UI.Elicitation to be true") + } + }) + + t.Run("setCapabilities with nil resets to empty", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + session.setCapabilities(&SessionCapabilities{ + UI: &UICapabilities{Elicitation: true}, + }) + session.setCapabilities(nil) + caps := session.Capabilities() + if caps.UI != nil { + t.Errorf("Expected UI to be nil after reset, got %+v", caps.UI) + } + }) +} + +func TestSession_ElicitationCapabilityGating(t *testing.T) { + t.Run("elicitation errors when capability is missing", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + err := session.assertElicitation() + if err == nil { + t.Fatal("Expected error when elicitation capability is missing") + } + expected := "elicitation is not supported" + if !containsString(err.Error(), expected) { + t.Errorf("Expected error to contain %q, got %q", expected, err.Error()) + } + }) + + t.Run("elicitation succeeds when capability is present", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + session.setCapabilities(&SessionCapabilities{ + UI: &UICapabilities{Elicitation: true}, + }) + err := session.assertElicitation() + if err != nil { + t.Errorf("Expected no error when elicitation capability is present, got %v", err) + } + }) +} + +func TestSession_ElicitationHandler(t *testing.T) { + t.Run("registerElicitationHandler stores handler", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + if session.getElicitationHandler() != nil { + t.Error("Expected nil handler before registration") + } + + session.registerElicitationHandler(func(req ElicitationRequest, inv ElicitationInvocation) (ElicitationResult, error) { + return ElicitationResult{Action: "accept"}, nil + }) + + if session.getElicitationHandler() == nil { + t.Error("Expected non-nil handler after registration") + } + }) +} + +func containsString(s, substr string) bool { + return len(s) >= len(substr) && searchSubstring(s, substr) +} + +func searchSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/go/types.go b/go/types.go index f888c9b6e..8c5d92dc8 100644 --- a/go/types.go +++ b/go/types.go @@ -111,6 +111,12 @@ func Float64(v float64) *float64 { return &v } +// Int returns a pointer to the given int value. +// Use for setting optional int parameters: MinLength: Int(1) +func Int(v int) *int { + return &v +} + // Known system prompt section identifiers for the "customize" mode. const ( SectionIdentity = "identity" @@ -489,6 +495,14 @@ type SessionConfig struct { // handler. Equivalent to calling session.On(handler) immediately after creation, // but executes earlier in the lifecycle so no events are missed. OnEvent SessionEventHandler + // Commands registers slash-commands for this session. Each command appears as + // /name in the CLI TUI for the user to invoke. The Handler is called when the + // command is executed. + Commands []CommandDefinition + // OnElicitationRequest is a handler for elicitation requests from the server. + // When provided, the server may call back to this client for form-based UI dialogs + // (e.g. from MCP tools). Also enables the elicitation capability on the session. + OnElicitationRequest ElicitationHandler } type Tool struct { Name string `json:"name"` @@ -527,6 +541,97 @@ type ToolResult struct { ToolTelemetry map[string]any `json:"toolTelemetry,omitempty"` } +// CommandContext provides context about a slash-command invocation. +type CommandContext struct { + // SessionID is the session where the command was invoked. + SessionID string + // Command is the full command text (e.g. "/deploy production"). + Command string + // CommandName is the command name without the leading / (e.g. "deploy"). + CommandName string + // Args is the raw argument string after the command name. + Args string +} + +// CommandHandler is invoked when a registered slash-command is executed. +type CommandHandler func(ctx CommandContext) error + +// CommandDefinition registers a slash-command. Name is shown in the CLI TUI +// as /name for the user to invoke. +type CommandDefinition struct { + // Name is the command name (without leading /). + Name string + // Description is a human-readable description shown in command completion UI. + Description string + // Handler is invoked when the command is executed. + Handler CommandHandler +} + +// SessionCapabilities describes what features the host supports. +type SessionCapabilities struct { + UI *UICapabilities `json:"ui,omitempty"` +} + +// UICapabilities describes host UI feature support. +type UICapabilities struct { + // Elicitation indicates whether the host supports interactive elicitation dialogs. + Elicitation bool `json:"elicitation,omitempty"` +} + +// ElicitationResult is the user's response to an elicitation dialog. +type ElicitationResult struct { + // Action is the user response: "accept" (submitted), "decline" (rejected), or "cancel" (dismissed). + Action string `json:"action"` + // Content holds form values submitted by the user (present when Action is "accept"). + Content map[string]any `json:"content,omitempty"` +} + +// ElicitationRequest describes an elicitation request from the server. +type ElicitationRequest struct { + // Message describes what information is needed from the user. + Message string + // RequestedSchema is a JSON Schema describing the form fields (form mode only). + RequestedSchema map[string]any + // Mode is "form" for structured input, "url" for browser redirect. + Mode string + // ElicitationSource is the source that initiated the request (e.g. MCP server name). + ElicitationSource string + // URL to open in the user's browser (url mode only). + URL string +} + +// ElicitationHandler handles elicitation requests from the server (e.g. from MCP tools). +// It receives the request and an ElicitationInvocation for context, and must return +// an ElicitationResult. If the handler returns an error the SDK auto-cancels the request. +type ElicitationHandler func(request ElicitationRequest, invocation ElicitationInvocation) (ElicitationResult, error) + +// ElicitationInvocation provides context about an elicitation request. +type ElicitationInvocation struct { + SessionID string +} + +// InputOptions configures a text input field for the Input convenience method. +type InputOptions struct { + // Title label for the input field. + Title string + // Description text shown below the field. + Description string + // MinLength is the minimum character length. + MinLength *int + // MaxLength is the maximum character length. + MaxLength *int + // Format is a semantic format hint: "email", "uri", "date", or "date-time". + Format string + // Default is the pre-populated value. + Default string +} + +// SessionUI provides convenience methods for showing elicitation dialogs to the user. +// Obtained via [Session.UI]. Methods error if the host does not support elicitation. +type SessionUI struct { + session *Session +} + // ResumeSessionConfig configures options when resuming a session type ResumeSessionConfig struct { // ClientName identifies the application using the SDK. @@ -585,6 +690,11 @@ type ResumeSessionConfig struct { // OnEvent is an optional event handler registered before the session.resume RPC // is issued, ensuring early events are delivered. See SessionConfig.OnEvent. OnEvent SessionEventHandler + // Commands registers slash-commands for this session. See SessionConfig.Commands. + Commands []CommandDefinition + // OnElicitationRequest is a handler for elicitation requests from the server. + // See SessionConfig.OnElicitationRequest. + OnElicitationRequest ElicitationHandler } type ProviderConfig struct { // Type is the provider type: "openai", "azure", or "anthropic". Defaults to "openai". @@ -742,71 +852,83 @@ type SessionLifecycleHandler func(event SessionLifecycleEvent) // createSessionRequest is the request for session.create type createSessionRequest struct { - Model string `json:"model,omitempty"` - SessionID string `json:"sessionId,omitempty"` - ClientName string `json:"clientName,omitempty"` - ReasoningEffort string `json:"reasoningEffort,omitempty"` - Tools []Tool `json:"tools,omitempty"` - SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` - AvailableTools []string `json:"availableTools"` - ExcludedTools []string `json:"excludedTools,omitempty"` - Provider *ProviderConfig `json:"provider,omitempty"` - RequestPermission *bool `json:"requestPermission,omitempty"` - RequestUserInput *bool `json:"requestUserInput,omitempty"` - Hooks *bool `json:"hooks,omitempty"` - WorkingDirectory string `json:"workingDirectory,omitempty"` - Streaming *bool `json:"streaming,omitempty"` - MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` - EnvValueMode string `json:"envValueMode,omitempty"` - CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` - Agent string `json:"agent,omitempty"` - ConfigDir string `json:"configDir,omitempty"` - SkillDirectories []string `json:"skillDirectories,omitempty"` - DisabledSkills []string `json:"disabledSkills,omitempty"` - InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` - Traceparent string `json:"traceparent,omitempty"` - Tracestate string `json:"tracestate,omitempty"` + Model string `json:"model,omitempty"` + SessionID string `json:"sessionId,omitempty"` + ClientName string `json:"clientName,omitempty"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` + AvailableTools []string `json:"availableTools"` + ExcludedTools []string `json:"excludedTools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + EnvValueMode string `json:"envValueMode,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + Agent string `json:"agent,omitempty"` + ConfigDir string `json:"configDir,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` + InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` + Commands []wireCommand `json:"commands,omitempty"` + RequestElicitation *bool `json:"requestElicitation,omitempty"` + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` +} + +// wireCommand is the wire representation of a command (name + description only, no handler). +type wireCommand struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` } // createSessionResponse is the response from session.create type createSessionResponse struct { - SessionID string `json:"sessionId"` - WorkspacePath string `json:"workspacePath"` + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` + Capabilities *SessionCapabilities `json:"capabilities,omitempty"` } // resumeSessionRequest is the request for session.resume type resumeSessionRequest struct { - SessionID string `json:"sessionId"` - ClientName string `json:"clientName,omitempty"` - Model string `json:"model,omitempty"` - ReasoningEffort string `json:"reasoningEffort,omitempty"` - Tools []Tool `json:"tools,omitempty"` - SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` - AvailableTools []string `json:"availableTools"` - ExcludedTools []string `json:"excludedTools,omitempty"` - Provider *ProviderConfig `json:"provider,omitempty"` - RequestPermission *bool `json:"requestPermission,omitempty"` - RequestUserInput *bool `json:"requestUserInput,omitempty"` - Hooks *bool `json:"hooks,omitempty"` - WorkingDirectory string `json:"workingDirectory,omitempty"` - ConfigDir string `json:"configDir,omitempty"` - DisableResume *bool `json:"disableResume,omitempty"` - Streaming *bool `json:"streaming,omitempty"` - MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` - EnvValueMode string `json:"envValueMode,omitempty"` - CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` - Agent string `json:"agent,omitempty"` - SkillDirectories []string `json:"skillDirectories,omitempty"` - DisabledSkills []string `json:"disabledSkills,omitempty"` - InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` - Traceparent string `json:"traceparent,omitempty"` - Tracestate string `json:"tracestate,omitempty"` + SessionID string `json:"sessionId"` + ClientName string `json:"clientName,omitempty"` + Model string `json:"model,omitempty"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` + AvailableTools []string `json:"availableTools"` + ExcludedTools []string `json:"excludedTools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + ConfigDir string `json:"configDir,omitempty"` + DisableResume *bool `json:"disableResume,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + EnvValueMode string `json:"envValueMode,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + Agent string `json:"agent,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` + InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` + Commands []wireCommand `json:"commands,omitempty"` + RequestElicitation *bool `json:"requestElicitation,omitempty"` + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` } // resumeSessionResponse is the response from session.resume type resumeSessionResponse struct { - SessionID string `json:"sessionId"` - WorkspacePath string `json:"workspacePath"` + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` + Capabilities *SessionCapabilities `json:"capabilities,omitempty"` } type hooksInvokeRequest struct { diff --git a/python/README.md b/python/README.md index 33f62c2d4..d6ae16724 100644 --- a/python/README.md +++ b/python/README.md @@ -700,6 +700,146 @@ async with await client.create_session( - `on_session_end` - Cleanup or logging when session ends. - `on_error_occurred` - Handle errors with retry/skip/abort strategies. +## Commands + +Register slash commands that users can invoke from the CLI TUI. When the user types `/commandName`, the SDK dispatches the event to your handler. + +```python +from copilot.session import CommandDefinition, CommandContext, PermissionHandler + +async def handle_deploy(ctx: CommandContext) -> None: + print(f"Deploying with args: {ctx.args}") + # ctx.session_id — the session where the command was invoked + # ctx.command — full command text (e.g. "/deploy production") + # ctx.command_name — command name without leading / (e.g. "deploy") + # ctx.args — raw argument string (e.g. "production") + +async with await client.create_session( + on_permission_request=PermissionHandler.approve_all, + commands=[ + CommandDefinition( + name="deploy", + description="Deploy the app", + handler=handle_deploy, + ), + CommandDefinition( + name="rollback", + description="Rollback to previous version", + handler=lambda ctx: print("Rolling back..."), + ), + ], +) as session: + ... +``` + +Commands can also be provided when resuming a session via `resume_session(commands=[...])`. + +## UI Elicitation + +The `session.ui` API provides convenience methods for asking the user questions through interactive dialogs. These methods are only available when the CLI host supports elicitation — check `session.capabilities` before calling. + +### Capability Check + +```python +ui_caps = session.capabilities.get("ui", {}) +if ui_caps.get("elicitation"): + # Safe to call session.ui methods + ... +``` + +### Confirm + +Shows a yes/no confirmation dialog: + +```python +ok = await session.ui.confirm("Deploy to production?") +if ok: + print("Deploying...") +``` + +### Select + +Shows a selection dialog with a list of options: + +```python +env = await session.ui.select("Choose environment:", ["staging", "production", "dev"]) +if env: + print(f"Selected: {env}") +``` + +### Input + +Shows a text input dialog with optional constraints: + +```python +name = await session.ui.input("Enter your name:") + +# With options +email = await session.ui.input("Enter email:", { + "title": "Email Address", + "description": "We'll use this for notifications", + "format": "email", +}) +``` + +### Custom Elicitation + +For full control, use the `elicitation()` method with a custom JSON schema: + +```python +result = await session.ui.elicitation({ + "message": "Configure deployment", + "requestedSchema": { + "type": "object", + "properties": { + "region": {"type": "string", "enum": ["us-east-1", "eu-west-1"]}, + "replicas": {"type": "number", "minimum": 1, "maximum": 10}, + }, + "required": ["region"], + }, +}) + +if result["action"] == "accept": + region = result["content"]["region"] + replicas = result["content"].get("replicas", 1) +``` + +## Elicitation Request Handler + +When the server (or an MCP tool) needs to ask the end-user a question, it sends an `elicitation.requested` event. Provide an `on_elicitation_request` handler to respond: + +```python +from copilot.session import ElicitationRequest, ElicitationResult, PermissionHandler + +async def handle_elicitation( + request: ElicitationRequest, invocation: dict[str, str] +) -> ElicitationResult: + # request["message"] — what the server is asking + # request.get("requestedSchema") — optional JSON schema for form fields + # request.get("mode") — "form" or "url" + # invocation["session_id"] — the session ID + + print(f"Server asks: {request['message']}") + + # Return the user's response + return { + "action": "accept", # or "decline" or "cancel" + "content": {"answer": "yes"}, + } + +async with await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_elicitation_request=handle_elicitation, +) as session: + ... +``` + +When `on_elicitation_request` is provided, the SDK automatically: +- Sends `requestElicitation: true` to the server during session creation/resumption +- Reports the `elicitation` capability on the session +- Dispatches `elicitation.requested` events to your handler +- Auto-cancels if your handler throws an error (so the server doesn't hang) + ## Requirements - Python 3.11+ diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 92764c0e8..a654cc69b 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -5,15 +5,37 @@ """ from .client import CopilotClient, ExternalServerConfig, SubprocessConfig -from .session import CopilotSession +from .session import ( + CommandContext, + CommandDefinition, + CopilotSession, + ElicitationHandler, + ElicitationParams, + ElicitationRequest, + ElicitationResult, + InputOptions, + SessionCapabilities, + SessionUiApi, + SessionUiCapabilities, +) from .tools import define_tool __version__ = "0.1.0" __all__ = [ + "CommandContext", + "CommandDefinition", "CopilotClient", "CopilotSession", + "ElicitationHandler", + "ElicitationParams", + "ElicitationRequest", + "ElicitationResult", "ExternalServerConfig", + "InputOptions", + "SessionCapabilities", + "SessionUiApi", + "SessionUiCapabilities", "SubprocessConfig", "define_tool", ] diff --git a/python/copilot/client.py b/python/copilot/client.py index ab8074756..356a5fd59 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -35,8 +35,10 @@ from .generated.rpc import ServerRpc from .generated.session_events import PermissionRequest, SessionEvent, session_event_from_dict from .session import ( + CommandDefinition, CopilotSession, CustomAgentConfig, + ElicitationHandler, InfiniteSessionConfig, MCPServerConfig, ProviderConfig, @@ -1114,6 +1116,8 @@ async def create_session( disabled_skills: list[str] | None = None, infinite_sessions: InfiniteSessionConfig | None = None, on_event: Callable[[SessionEvent], None] | None = None, + commands: list[CommandDefinition] | None = None, + on_elicitation_request: ElicitationHandler | None = None, ) -> CopilotSession: """ Create a new conversation session with the Copilot CLI. @@ -1218,6 +1222,15 @@ async def create_session( if on_user_input_request: payload["requestUserInput"] = True + # Enable elicitation request callback if handler provided + payload["requestElicitation"] = bool(on_elicitation_request) + + # Serialize commands (name + description only) into payload + if commands: + payload["commands"] = [ + {"name": cmd.name, "description": cmd.description} for cmd in commands + ] + # Enable hooks callback if any hook handler provided if hooks and any(hooks.values()): payload["hooks"] = True @@ -1290,9 +1303,12 @@ async def create_session( # events emitted by the CLI (e.g. session.start) are not dropped. session = CopilotSession(actual_session_id, self._client, workspace_path=None) session._register_tools(tools) + session._register_commands(commands) session._register_permission_handler(on_permission_request) if on_user_input_request: session._register_user_input_handler(on_user_input_request) + if on_elicitation_request: + session._register_elicitation_handler(on_elicitation_request) if hooks: session._register_hooks(hooks) if transform_callbacks: @@ -1305,6 +1321,8 @@ async def create_session( try: response = await self._client.request("session.create", payload) session._workspace_path = response.get("workspacePath") + capabilities = response.get("capabilities") + session._set_capabilities(capabilities) except BaseException: with self._sessions_lock: self._sessions.pop(actual_session_id, None) @@ -1337,6 +1355,8 @@ async def resume_session( disabled_skills: list[str] | None = None, infinite_sessions: InfiniteSessionConfig | None = None, on_event: Callable[[SessionEvent], None] | None = None, + commands: list[CommandDefinition] | None = None, + on_elicitation_request: ElicitationHandler | None = None, ) -> CopilotSession: """ Resume an existing conversation session by its ID. @@ -1444,6 +1464,15 @@ async def resume_session( if on_user_input_request: payload["requestUserInput"] = True + # Enable elicitation request callback if handler provided + payload["requestElicitation"] = bool(on_elicitation_request) + + # Serialize commands (name + description only) into payload + if commands: + payload["commands"] = [ + {"name": cmd.name, "description": cmd.description} for cmd in commands + ] + if hooks and any(hooks.values()): payload["hooks"] = True @@ -1494,9 +1523,12 @@ async def resume_session( # events emitted by the CLI (e.g. session.start) are not dropped. session = CopilotSession(session_id, self._client, workspace_path=None) session._register_tools(tools) + session._register_commands(commands) session._register_permission_handler(on_permission_request) if on_user_input_request: session._register_user_input_handler(on_user_input_request) + if on_elicitation_request: + session._register_elicitation_handler(on_elicitation_request) if hooks: session._register_hooks(hooks) if transform_callbacks: @@ -1509,6 +1541,8 @@ async def resume_session( try: response = await self._client.request("session.resume", payload) session._workspace_path = response.get("workspacePath") + capabilities = response.get("capabilities") + session._set_capabilities(capabilities) except BaseException: with self._sessions_lock: self._sessions.pop(session_id, None) @@ -2225,10 +2259,26 @@ def __init__(self, sock_file, sock_obj): self._socket = sock_obj def terminate(self): + import socket as _socket_mod + + # shutdown() sends TCP FIN to the server (triggering + # server-side disconnect detection) and interrupts any + # pending blocking reads on other threads immediately. + try: + self._socket.shutdown(_socket_mod.SHUT_RDWR) + except OSError: + pass # Safe to ignore — socket may already be closed + # Close the file wrapper — makefile() holds its own + # reference to the fd, so socket.close() alone won't + # release the OS resource until the wrapper is closed too. + try: + self.stdin.close() + except OSError: + pass # Safe to ignore — already closed try: self._socket.close() except OSError: - pass + pass # Safe to ignore — already closed def kill(self): self.terminate() diff --git a/python/copilot/session.py b/python/copilot/session.py index 019436f7a..f935b0c90 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -22,15 +22,24 @@ from ._jsonrpc import JsonRpcError, ProcessExitedError from ._telemetry import get_trace_context, trace_context from .generated.rpc import ( + Action, Kind, Level, + Property, + PropertyType, + RequestedSchema, + RequestedSchemaType, ResultResult, + SessionCommandsHandlePendingCommandParams, SessionLogParams, SessionModelSwitchToParams, SessionPermissionsHandlePendingPermissionRequestParams, SessionPermissionsHandlePendingPermissionRequestParamsResult, SessionRpc, SessionToolsHandlePendingToolCallParams, + SessionUIElicitationParams, + SessionUIHandlePendingElicitationParams, + SessionUIHandlePendingElicitationParamsResult, ) from .generated.session_events import ( PermissionRequest, @@ -250,6 +259,281 @@ class UserInputResponse(TypedDict): UserInputResponse | Awaitable[UserInputResponse], ] +# ============================================================================ +# Command Types +# ============================================================================ + + +@dataclass +class CommandContext: + """Context passed to a command handler when a command is executed.""" + + session_id: str + """Session ID where the command was invoked.""" + command: str + """The full command text (e.g. ``"/deploy production"``).""" + command_name: str + """Command name without leading ``/``.""" + args: str + """Raw argument string after the command name.""" + + +CommandHandler = Callable[[CommandContext], Awaitable[None] | None] +"""Handler invoked when a registered command is executed by a user.""" + + +@dataclass +class CommandDefinition: + """Definition of a slash command registered with the session. + + When the CLI is running with a TUI, registered commands appear as + ``/commandName`` for the user to invoke. + """ + + name: str + """Command name (without leading ``/``).""" + handler: CommandHandler + """Handler invoked when the command is executed.""" + description: str | None = None + """Human-readable description shown in command completion UI.""" + + +# ============================================================================ +# Session Capabilities +# ============================================================================ + + +class SessionUiCapabilities(TypedDict, total=False): + """UI capabilities reported by the CLI host.""" + + elicitation: bool + """Whether the host supports interactive elicitation dialogs.""" + + +class SessionCapabilities(TypedDict, total=False): + """Capabilities reported by the CLI host for this session.""" + + ui: SessionUiCapabilities + + +# ============================================================================ +# Elicitation Types (client → server) +# ============================================================================ + +ElicitationFieldValue = str | float | bool | list[str] +"""Possible value types in elicitation form content.""" + + +class ElicitationResult(TypedDict, total=False): + """Result returned from an elicitation request.""" + + action: Required[Literal["accept", "decline", "cancel"]] + """User action: ``"accept"`` (submitted), ``"decline"`` (rejected), + or ``"cancel"`` (dismissed).""" + content: dict[str, ElicitationFieldValue] + """Form values submitted by the user (present when action is ``"accept"``).""" + + +class ElicitationParams(TypedDict): + """Parameters for a raw elicitation request.""" + + message: str + """Message describing what information is needed from the user.""" + requestedSchema: dict[str, Any] + """JSON Schema describing the form fields to present.""" + + +class InputOptions(TypedDict, total=False): + """Options for the ``input()`` convenience method.""" + + title: str + """Title label for the input field.""" + description: str + """Descriptive text shown below the field.""" + minLength: int + """Minimum text length.""" + maxLength: int + """Maximum text length.""" + format: str + """Input format hint (e.g. ``"email"``, ``"uri"``, ``"date"``).""" + default: str + """Default value for the input field.""" + + +# ============================================================================ +# Elicitation Types (server → client callback) +# ============================================================================ + + +class ElicitationRequest(TypedDict, total=False): + """Request payload passed to an elicitation handler callback.""" + + message: Required[str] + """Message describing what information is needed from the user.""" + requestedSchema: dict[str, Any] + """JSON Schema describing the form fields to present.""" + mode: Literal["form", "url"] + """Elicitation mode: ``"form"`` for structured input, ``"url"`` for browser redirect.""" + elicitationSource: str + """The source that initiated the request (e.g. MCP server name).""" + url: str + """URL to open in the browser (when mode is ``"url"``).""" + + +ElicitationHandler = Callable[ + [ElicitationRequest, dict[str, str]], + ElicitationResult | Awaitable[ElicitationResult], +] +"""Handler invoked when the server dispatches an elicitation request to this client.""" + + +# ============================================================================ +# Session UI API +# ============================================================================ + + +class SessionUiApi: + """Interactive UI methods for showing dialogs to the user. + + Only available when the CLI host supports elicitation + (``session.capabilities["ui"]["elicitation"] is True``). + + Obtained via :attr:`CopilotSession.ui`. + """ + + def __init__(self, session: CopilotSession) -> None: + self._session = session + + async def elicitation(self, params: ElicitationParams) -> ElicitationResult: + """Shows a generic elicitation dialog with a custom schema. + + Args: + params: Elicitation parameters including message and requestedSchema. + + Returns: + The user's response (action + optional content). + + Raises: + RuntimeError: If the host does not support elicitation. + """ + self._session._assert_elicitation() + rpc_result = await self._session.rpc.ui.elicitation( + SessionUIElicitationParams( + message=params["message"], + requested_schema=RequestedSchema.from_dict(params["requestedSchema"]), + ) + ) + result: ElicitationResult = {"action": rpc_result.action.value} # type: ignore[typeddict-item] + if rpc_result.content is not None: + result["content"] = rpc_result.content + return result + + async def confirm(self, message: str) -> bool: + """Shows a confirmation dialog and returns the user's boolean answer. + + Args: + message: The question to ask the user. + + Returns: + ``True`` if the user accepted, ``False`` otherwise. + + Raises: + RuntimeError: If the host does not support elicitation. + """ + self._session._assert_elicitation() + rpc_result = await self._session.rpc.ui.elicitation( + SessionUIElicitationParams( + message=message, + requested_schema=RequestedSchema( + type=RequestedSchemaType.OBJECT, + properties={ + "confirmed": Property(type=PropertyType.BOOLEAN, default=True), + }, + required=["confirmed"], + ), + ) + ) + return ( + rpc_result.action == Action.ACCEPT + and rpc_result.content is not None + and rpc_result.content.get("confirmed") is True + ) + + async def select(self, message: str, options: list[str]) -> str | None: + """Shows a selection dialog with a list of options. + + Args: + message: Instruction to show the user. + options: List of choices the user can pick from. + + Returns: + The selected string, or ``None`` if the user declined/cancelled. + + Raises: + RuntimeError: If the host does not support elicitation. + """ + self._session._assert_elicitation() + rpc_result = await self._session.rpc.ui.elicitation( + SessionUIElicitationParams( + message=message, + requested_schema=RequestedSchema( + type=RequestedSchemaType.OBJECT, + properties={ + "selection": Property(type=PropertyType.STRING, enum=options), + }, + required=["selection"], + ), + ) + ) + if ( + rpc_result.action == Action.ACCEPT + and rpc_result.content is not None + and rpc_result.content.get("selection") is not None + ): + return str(rpc_result.content["selection"]) + return None + + async def input(self, message: str, options: InputOptions | None = None) -> str | None: + """Shows a text input dialog. + + Args: + message: Instruction to show the user. + options: Optional constraints for the input field. + + Returns: + The entered text, or ``None`` if the user declined/cancelled. + + Raises: + RuntimeError: If the host does not support elicitation. + """ + self._session._assert_elicitation() + field: dict[str, Any] = {"type": "string"} + if options: + for key in ("title", "description", "minLength", "maxLength", "format", "default"): + if key in options: + field[key] = options[key] + + rpc_result = await self._session.rpc.ui.elicitation( + SessionUIElicitationParams( + message=message, + requested_schema=RequestedSchema.from_dict( + { + "type": "object", + "properties": {"value": field}, + "required": ["value"], + } + ), + ) + ) + if ( + rpc_result.action == Action.ACCEPT + and rpc_result.content is not None + and rpc_result.content.get("value") is not None + ): + return str(rpc_result.content["value"]) + return None + + # ============================================================================ # Hook Types # ============================================================================ @@ -563,6 +847,12 @@ class SessionConfig(TypedDict, total=False): # are delivered. Equivalent to calling session.on(handler) immediately # after creation, but executes earlier in the lifecycle so no events are missed. on_event: Callable[[SessionEvent], None] + # Slash commands to register with the session. + # When the CLI has a TUI, each command appears as /name for the user to invoke. + commands: list[CommandDefinition] + # Handler for elicitation requests from the server. + # When provided, the server calls back to this client for form-based UI dialogs. + on_elicitation_request: ElicitationHandler class ResumeSessionConfig(TypedDict, total=False): @@ -612,6 +902,10 @@ class ResumeSessionConfig(TypedDict, total=False): # Optional event handler registered before the session.resume RPC is issued, # ensuring early events are delivered. See SessionConfig.on_event. on_event: Callable[[SessionEvent], None] + # Slash commands to register with the session. + commands: list[CommandDefinition] + # Handler for elicitation requests from the server. + on_elicitation_request: ElicitationHandler SessionEventHandler = Callable[[SessionEvent], None] @@ -676,6 +970,11 @@ def __init__( self._hooks_lock = threading.Lock() self._transform_callbacks: dict[str, SectionTransformFn] | None = None self._transform_callbacks_lock = threading.Lock() + self._command_handlers: dict[str, CommandHandler] = {} + self._command_handlers_lock = threading.Lock() + self._elicitation_handler: ElicitationHandler | None = None + self._elicitation_handler_lock = threading.Lock() + self._capabilities: SessionCapabilities = {} self._rpc: SessionRpc | None = None self._destroyed = False @@ -686,6 +985,28 @@ def rpc(self) -> SessionRpc: self._rpc = SessionRpc(self._client, self.session_id) return self._rpc + @property + def capabilities(self) -> SessionCapabilities: + """Host capabilities reported when the session was created or resumed. + + Use this to check feature support before calling capability-gated APIs. + """ + return self._capabilities + + @property + def ui(self) -> SessionUiApi: + """Interactive UI methods for showing dialogs to the user. + + Only available when the CLI host supports elicitation + (``session.capabilities.get("ui", {}).get("elicitation") is True``). + + Example: + >>> ui_caps = session.capabilities.get("ui", {}) + >>> if ui_caps.get("elicitation"): + ... ok = await session.ui.confirm("Deploy to production?") + """ + return SessionUiApi(self) + @functools.cached_property def workspace_path(self) -> pathlib.Path | None: """ @@ -909,6 +1230,47 @@ def _handle_broadcast_event(self, event: SessionEvent) -> None: self._execute_permission_and_respond(request_id, permission_request, perm_handler) ) + elif event.type == SessionEventType.COMMAND_EXECUTE: + request_id = event.data.request_id + command_name = event.data.command_name + command = event.data.command + args = event.data.args + if not request_id or not command_name: + return + asyncio.ensure_future( + self._execute_command_and_respond( + request_id, command_name, command or "", args or "" + ) + ) + + elif event.type == SessionEventType.ELICITATION_REQUESTED: + with self._elicitation_handler_lock: + handler = self._elicitation_handler + if not handler: + return + request_id = event.data.request_id + if not request_id: + return + request: ElicitationRequest = {"message": event.data.message or ""} + if event.data.requested_schema is not None: + request["requestedSchema"] = event.data.requested_schema.to_dict() + if event.data.mode is not None: + request["mode"] = event.data.mode.value + if event.data.elicitation_source is not None: + request["elicitationSource"] = event.data.elicitation_source + if event.data.url is not None: + request["url"] = event.data.url + asyncio.ensure_future(self._handle_elicitation_request(request, request_id)) + + elif event.type == SessionEventType.CAPABILITIES_CHANGED: + cap: SessionCapabilities = {} + if event.data.ui is not None: + ui_cap: SessionUiCapabilities = {} + if event.data.ui.elicitation is not None: + ui_cap["elicitation"] = event.data.ui.elicitation + cap["ui"] = ui_cap + self._capabilities = {**self._capabilities, **cap} + async def _execute_tool_and_respond( self, request_id: str, @@ -1019,6 +1381,138 @@ async def _execute_permission_and_respond( except (JsonRpcError, ProcessExitedError, OSError): pass # Connection lost or RPC error — nothing we can do + async def _execute_command_and_respond( + self, + request_id: str, + command_name: str, + command: str, + args: str, + ) -> None: + """Execute a command handler and send the result back via RPC.""" + with self._command_handlers_lock: + handler = self._command_handlers.get(command_name) + + if not handler: + try: + await self.rpc.commands.handle_pending_command( + SessionCommandsHandlePendingCommandParams( + request_id=request_id, + error=f"Unknown command: {command_name}", + ) + ) + except (JsonRpcError, ProcessExitedError, OSError): + pass # Connection lost — nothing we can do + return + + try: + ctx = CommandContext( + session_id=self.session_id, + command=command, + command_name=command_name, + args=args, + ) + result = handler(ctx) + if inspect.isawaitable(result): + await result + await self.rpc.commands.handle_pending_command( + SessionCommandsHandlePendingCommandParams(request_id=request_id) + ) + except Exception as exc: + message = str(exc) + try: + await self.rpc.commands.handle_pending_command( + SessionCommandsHandlePendingCommandParams( + request_id=request_id, + error=message, + ) + ) + except (JsonRpcError, ProcessExitedError, OSError): + pass # Connection lost — nothing we can do + + async def _handle_elicitation_request( + self, + request: ElicitationRequest, + request_id: str, + ) -> None: + """Handle an elicitation.requested broadcast event. + + Invokes the registered handler and responds via handlePendingElicitation RPC. + Auto-cancels on error so the server doesn't hang. + """ + with self._elicitation_handler_lock: + handler = self._elicitation_handler + if not handler: + return + try: + result = handler(request, {"session_id": self.session_id}) + if inspect.isawaitable(result): + result = await result + result = cast(ElicitationResult, result) + action_val = result.get("action", "cancel") + rpc_result = SessionUIHandlePendingElicitationParamsResult( + action=Action(action_val), + content=result.get("content"), + ) + await self.rpc.ui.handle_pending_elicitation( + SessionUIHandlePendingElicitationParams( + request_id=request_id, + result=rpc_result, + ) + ) + except Exception: + # Handler failed — attempt to cancel so the request doesn't hang + try: + await self.rpc.ui.handle_pending_elicitation( + SessionUIHandlePendingElicitationParams( + request_id=request_id, + result=SessionUIHandlePendingElicitationParamsResult( + action=Action.CANCEL, + ), + ) + ) + except (JsonRpcError, ProcessExitedError, OSError): + pass # Connection lost or RPC error — nothing we can do + + def _assert_elicitation(self) -> None: + """Raises if the host does not support elicitation.""" + ui_caps = self._capabilities.get("ui", {}) + if not ui_caps.get("elicitation"): + raise RuntimeError( + "Elicitation is not supported by the host. " + "Check session.capabilities before calling UI methods." + ) + + def _register_commands(self, commands: list[CommandDefinition] | None) -> None: + """Register command handlers for this session. + + Args: + commands: A list of CommandDefinition objects, or None to clear all commands. + """ + with self._command_handlers_lock: + self._command_handlers.clear() + if not commands: + return + for cmd in commands: + self._command_handlers[cmd.name] = cmd.handler + + def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> None: + """Register the elicitation handler for this session. + + Args: + handler: The handler to invoke when the server dispatches an + elicitation request, or None to remove the handler. + """ + with self._elicitation_handler_lock: + self._elicitation_handler = handler + + def _set_capabilities(self, capabilities: SessionCapabilities | None) -> None: + """Set the host capabilities for this session. + + Args: + capabilities: The capabilities object from the create/resume response. + """ + self._capabilities: SessionCapabilities = capabilities if capabilities is not None else {} + def _register_tools(self, tools: list[Tool] | None) -> None: """ Register custom tool handlers for this session. @@ -1312,6 +1806,10 @@ async def disconnect(self) -> None: self._tool_handlers.clear() with self._permission_handler_lock: self._permission_handler = None + with self._command_handlers_lock: + self._command_handlers.clear() + with self._elicitation_handler_lock: + self._elicitation_handler = None async def destroy(self) -> None: """ diff --git a/python/e2e/test_commands.py b/python/e2e/test_commands.py new file mode 100644 index 000000000..f2eb7cdf1 --- /dev/null +++ b/python/e2e/test_commands.py @@ -0,0 +1,212 @@ +"""E2E Commands Tests + +Mirrors nodejs/test/e2e/commands.test.ts + +Multi-client test: a second client joining a session with commands should +trigger a ``commands.changed`` broadcast event visible to the first client. +""" + +import asyncio +import os +import shutil +import tempfile + +import pytest +import pytest_asyncio + +from copilot import CopilotClient +from copilot.client import ExternalServerConfig, SubprocessConfig +from copilot.session import CommandDefinition, PermissionHandler + +from .testharness.context import SNAPSHOTS_DIR, get_cli_path_for_tests +from .testharness.proxy import CapiProxy + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +# --------------------------------------------------------------------------- +# Multi-client context (TCP mode) — same pattern as test_multi_client.py +# --------------------------------------------------------------------------- + + +class CommandsMultiClientContext: + """Test context that manages two clients connected to the same CLI server.""" + + def __init__(self): + self.cli_path: str = "" + self.home_dir: str = "" + self.work_dir: str = "" + self.proxy_url: str = "" + self._proxy: CapiProxy | None = None + self._client1: CopilotClient | None = None + self._client2: CopilotClient | None = None + + async def setup(self): + self.cli_path = get_cli_path_for_tests() + self.home_dir = tempfile.mkdtemp(prefix="copilot-cmd-config-") + self.work_dir = tempfile.mkdtemp(prefix="copilot-cmd-work-") + + self._proxy = CapiProxy() + self.proxy_url = await self._proxy.start() + + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + + # Client 1 uses TCP mode so a second client can connect + self._client1 = CopilotClient( + SubprocessConfig( + cli_path=self.cli_path, + cwd=self.work_dir, + env=self._get_env(), + use_stdio=False, + github_token=github_token, + ) + ) + + # Trigger connection to get the port + init_session = await self._client1.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + await init_session.disconnect() + + actual_port = self._client1.actual_port + assert actual_port is not None + + self._client2 = CopilotClient(ExternalServerConfig(url=f"localhost:{actual_port}")) + + async def teardown(self, test_failed: bool = False): + for c in (self._client2, self._client1): + if c: + try: + await c.stop() + except Exception: + pass # Best-effort cleanup during teardown + self._client1 = self._client2 = None + + if self._proxy: + await self._proxy.stop(skip_writing_cache=test_failed) + self._proxy = None + + for d in (self.home_dir, self.work_dir): + if d and os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + + async def configure_for_test(self, test_file: str, test_name: str): + import re + + sanitized_name = re.sub(r"[^a-zA-Z0-9]", "_", test_name).lower() + snapshot_path = SNAPSHOTS_DIR / test_file / f"{sanitized_name}.yaml" + if self._proxy: + await self._proxy.configure(str(snapshot_path.resolve()), self.work_dir) + from pathlib import Path + + for d in (self.home_dir, self.work_dir): + for item in Path(d).iterdir(): + if item.is_dir(): + shutil.rmtree(item, ignore_errors=True) + else: + item.unlink(missing_ok=True) + + def _get_env(self) -> dict: + env = os.environ.copy() + env.update( + { + "COPILOT_API_URL": self.proxy_url, + "XDG_CONFIG_HOME": self.home_dir, + "XDG_STATE_HOME": self.home_dir, + } + ) + return env + + @property + def client1(self) -> CopilotClient: + assert self._client1 is not None + return self._client1 + + @property + def client2(self) -> CopilotClient: + assert self._client2 is not None + return self._client2 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + rep = outcome.get_result() + if rep.when == "call" and rep.failed: + item.session.stash.setdefault("any_test_failed", False) + item.session.stash["any_test_failed"] = True + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def mctx(request): + context = CommandsMultiClientContext() + await context.setup() + yield context + any_failed = request.session.stash.get("any_test_failed", False) + await context.teardown(test_failed=any_failed) + + +@pytest_asyncio.fixture(autouse=True, loop_scope="module") +async def configure_cmd_test(request, mctx): + test_name = request.node.name + if test_name.startswith("test_"): + test_name = test_name[5:] + await mctx.configure_for_test("multi_client", test_name) + yield + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestCommands: + async def test_client_receives_commands_changed_when_another_client_joins( + self, mctx: CommandsMultiClientContext + ): + """Client receives commands.changed when another client joins with commands.""" + # Client 1 creates a session without commands + session1 = await mctx.client1.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + + # Listen for the commands.changed event + commands_changed = asyncio.Event() + commands_data: dict = {} + + def on_event(event): + if event.type.value == "commands.changed": + commands_data["commands"] = getattr(event.data, "commands", None) + commands_changed.set() + + session1.on(on_event) + + # Client 2 joins the same session with commands + session2 = await mctx.client2.resume_session( + session1.session_id, + on_permission_request=PermissionHandler.approve_all, + commands=[ + CommandDefinition( + name="deploy", + description="Deploy the app", + handler=lambda ctx: None, + ), + ], + ) + + # Wait for the commands.changed event (with timeout) + await asyncio.wait_for(commands_changed.wait(), timeout=15.0) + + # Verify the event contains the deploy command + assert commands_data.get("commands") is not None + cmd_names = [c.name for c in commands_data["commands"]] + assert "deploy" in cmd_names + + await session2.disconnect() diff --git a/python/e2e/test_ui_elicitation.py b/python/e2e/test_ui_elicitation.py new file mode 100644 index 000000000..0bb5da00c --- /dev/null +++ b/python/e2e/test_ui_elicitation.py @@ -0,0 +1,58 @@ +"""E2E UI Elicitation Tests (single-client) + +Mirrors nodejs/test/e2e/ui_elicitation.test.ts — single-client scenarios. + +Uses the shared ``ctx`` fixture from conftest.py. +""" + +import pytest + +from copilot.session import ( + ElicitationRequest, + ElicitationResult, + PermissionHandler, +) + +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class TestUiElicitation: + async def test_elicitation_methods_throw_in_headless_mode(self, ctx: E2ETestContext): + """Elicitation methods throw when running in headless mode.""" + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + + # The SDK spawns the CLI headless — no TUI means no elicitation support. + ui_caps = session.capabilities.get("ui", {}) + assert not ui_caps.get("elicitation") + + with pytest.raises(RuntimeError, match="not supported"): + await session.ui.confirm("test") + + async def test_session_with_elicitation_handler_reports_capability(self, ctx: E2ETestContext): + """Session created with onElicitationRequest reports elicitation capability.""" + + async def handler( + request: ElicitationRequest, invocation: dict[str, str] + ) -> ElicitationResult: + return {"action": "accept", "content": {}} + + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_elicitation_request=handler, + ) + + assert session.capabilities.get("ui", {}).get("elicitation") is True + + async def test_session_without_elicitation_handler_reports_no_capability( + self, ctx: E2ETestContext + ): + """Session created without onElicitationRequest reports no elicitation capability.""" + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + + assert session.capabilities.get("ui", {}).get("elicitation") in (False, None) diff --git a/python/e2e/test_ui_elicitation_multi_client.py b/python/e2e/test_ui_elicitation_multi_client.py new file mode 100644 index 000000000..f428d4083 --- /dev/null +++ b/python/e2e/test_ui_elicitation_multi_client.py @@ -0,0 +1,284 @@ +"""E2E UI Elicitation Tests (multi-client) + +Mirrors nodejs/test/e2e/ui_elicitation.test.ts — multi-client scenarios. + +Tests: + - capabilities.changed fires when second client joins with elicitation handler + - capabilities.changed fires when elicitation provider disconnects +""" + +import asyncio +import os +import shutil +import tempfile + +import pytest +import pytest_asyncio + +from copilot import CopilotClient +from copilot.client import ExternalServerConfig, SubprocessConfig +from copilot.session import ( + ElicitationRequest, + ElicitationResult, + PermissionHandler, +) + +from .testharness.context import SNAPSHOTS_DIR, get_cli_path_for_tests +from .testharness.proxy import CapiProxy + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +# --------------------------------------------------------------------------- +# Multi-client context (TCP mode) — same pattern as test_multi_client.py +# --------------------------------------------------------------------------- + + +class ElicitationMultiClientContext: + """Test context managing multiple clients on one CLI server.""" + + def __init__(self): + self.cli_path: str = "" + self.home_dir: str = "" + self.work_dir: str = "" + self.proxy_url: str = "" + self._proxy: CapiProxy | None = None + self._client1: CopilotClient | None = None + self._client2: CopilotClient | None = None + self._actual_port: int | None = None + + async def setup(self): + self.cli_path = get_cli_path_for_tests() + self.home_dir = tempfile.mkdtemp(prefix="copilot-elicit-config-") + self.work_dir = tempfile.mkdtemp(prefix="copilot-elicit-work-") + + self._proxy = CapiProxy() + self.proxy_url = await self._proxy.start() + + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + + # Client 1 uses TCP mode so additional clients can connect + self._client1 = CopilotClient( + SubprocessConfig( + cli_path=self.cli_path, + cwd=self.work_dir, + env=self._get_env(), + use_stdio=False, + github_token=github_token, + ) + ) + + # Trigger connection to obtain the TCP port + init_session = await self._client1.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + await init_session.disconnect() + + self._actual_port = self._client1.actual_port + assert self._actual_port is not None + + self._client2 = CopilotClient(ExternalServerConfig(url=f"localhost:{self._actual_port}")) + + async def teardown(self, test_failed: bool = False): + for c in (self._client2, self._client1): + if c: + try: + await c.stop() + except Exception: + pass # Best-effort cleanup during teardown + self._client1 = self._client2 = None + + if self._proxy: + await self._proxy.stop(skip_writing_cache=test_failed) + self._proxy = None + + for d in (self.home_dir, self.work_dir): + if d and os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + + async def configure_for_test(self, test_file: str, test_name: str): + import re + + sanitized_name = re.sub(r"[^a-zA-Z0-9]", "_", test_name).lower() + snapshot_path = SNAPSHOTS_DIR / test_file / f"{sanitized_name}.yaml" + if self._proxy: + await self._proxy.configure(str(snapshot_path.resolve()), self.work_dir) + from pathlib import Path + + for d in (self.home_dir, self.work_dir): + for item in Path(d).iterdir(): + if item.is_dir(): + shutil.rmtree(item, ignore_errors=True) + else: + item.unlink(missing_ok=True) + + def _get_env(self) -> dict: + env = os.environ.copy() + env.update( + { + "COPILOT_API_URL": self.proxy_url, + "XDG_CONFIG_HOME": self.home_dir, + "XDG_STATE_HOME": self.home_dir, + } + ) + return env + + def make_external_client(self) -> CopilotClient: + """Create a new external client connected to the same CLI server.""" + assert self._actual_port is not None + return CopilotClient(ExternalServerConfig(url=f"localhost:{self._actual_port}")) + + @property + def client1(self) -> CopilotClient: + assert self._client1 is not None + return self._client1 + + @property + def client2(self) -> CopilotClient: + assert self._client2 is not None + return self._client2 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + rep = outcome.get_result() + if rep.when == "call" and rep.failed: + item.session.stash.setdefault("any_test_failed", False) + item.session.stash["any_test_failed"] = True + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def mctx(request): + context = ElicitationMultiClientContext() + await context.setup() + yield context + any_failed = request.session.stash.get("any_test_failed", False) + await context.teardown(test_failed=any_failed) + + +@pytest_asyncio.fixture(autouse=True, loop_scope="module") +async def configure_elicit_multi_test(request, mctx): + test_name = request.node.name + if test_name.startswith("test_"): + test_name = test_name[5:] + await mctx.configure_for_test("multi_client", test_name) + yield + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestUiElicitationMultiClient: + async def test_capabilities_changed_when_second_client_joins_with_elicitation( + self, mctx: ElicitationMultiClientContext + ): + """capabilities.changed fires when second client joins with elicitation handler.""" + # Client 1 creates session without elicitation + session1 = await mctx.client1.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + assert session1.capabilities.get("ui", {}).get("elicitation") in (False, None) + + # Listen for capabilities.changed event + cap_changed = asyncio.Event() + cap_event_data: dict = {} + + def on_event(event): + if event.type.value == "capabilities.changed": + ui = getattr(event.data, "ui", None) + if ui: + cap_event_data["elicitation"] = getattr(ui, "elicitation", None) + cap_changed.set() + + unsubscribe = session1.on(on_event) + + # Client 2 joins WITH elicitation handler — triggers capabilities.changed + async def handler( + request: ElicitationRequest, invocation: dict[str, str] + ) -> ElicitationResult: + return {"action": "accept", "content": {}} + + session2 = await mctx.client2.resume_session( + session1.session_id, + on_permission_request=PermissionHandler.approve_all, + on_elicitation_request=handler, + ) + + await asyncio.wait_for(cap_changed.wait(), timeout=15.0) + unsubscribe() + + # The event should report elicitation as True + assert cap_event_data.get("elicitation") is True + + # Client 1's capabilities should have been auto-updated + assert session1.capabilities.get("ui", {}).get("elicitation") is True + + await session2.disconnect() + + async def test_capabilities_changed_when_elicitation_provider_disconnects( + self, mctx: ElicitationMultiClientContext + ): + """capabilities.changed fires when elicitation provider disconnects.""" + # Client 1 creates session without elicitation + session1 = await mctx.client1.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + assert session1.capabilities.get("ui", {}).get("elicitation") in (False, None) + + # Wait for elicitation to become available + cap_enabled = asyncio.Event() + + def on_enabled(event): + if event.type.value == "capabilities.changed": + ui = getattr(event.data, "ui", None) + if ui and getattr(ui, "elicitation", None) is True: + cap_enabled.set() + + unsub_enabled = session1.on(on_enabled) + + # Use a dedicated client so we can stop it independently + client3 = mctx.make_external_client() + + async def handler( + request: ElicitationRequest, invocation: dict[str, str] + ) -> ElicitationResult: + return {"action": "accept", "content": {}} + + # Client 3 joins WITH elicitation handler + await client3.resume_session( + session1.session_id, + on_permission_request=PermissionHandler.approve_all, + on_elicitation_request=handler, + ) + + await asyncio.wait_for(cap_enabled.wait(), timeout=15.0) + unsub_enabled() + assert session1.capabilities.get("ui", {}).get("elicitation") is True + + # Now listen for the capability being removed + cap_disabled = asyncio.Event() + + def on_disabled(event): + if event.type.value == "capabilities.changed": + ui = getattr(event.data, "ui", None) + if ui and getattr(ui, "elicitation", None) is False: + cap_disabled.set() + + unsub_disabled = session1.on(on_disabled) + + # Force-stop client 3 — destroys the socket, triggering server-side cleanup + await client3.force_stop() + + await asyncio.wait_for(cap_disabled.wait(), timeout=15.0) + unsub_disabled() + assert session1.capabilities.get("ui", {}).get("elicitation") is False diff --git a/python/test_commands_and_elicitation.py b/python/test_commands_and_elicitation.py new file mode 100644 index 000000000..3c2f9c857 --- /dev/null +++ b/python/test_commands_and_elicitation.py @@ -0,0 +1,587 @@ +""" +Unit tests for Commands, UI Elicitation (client→server), and +onElicitationRequest (server→client callback) features. + +Mirrors the Node.js client.test.ts tests for these features. +""" + +import asyncio + +import pytest + +from copilot import CopilotClient +from copilot.client import SubprocessConfig +from copilot.session import ( + CommandContext, + CommandDefinition, + ElicitationRequest, + ElicitationResult, + PermissionHandler, +) +from e2e.testharness import CLI_PATH + +# ============================================================================ +# Commands +# ============================================================================ + + +class TestCommands: + @pytest.mark.asyncio + async def test_forwards_commands_in_session_create_rpc(self): + """Verifies that commands (name + description) are serialized in session.create payload.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + captured: dict = {} + original_request = client._client.request + + async def mock_request(method, params): + captured[method] = params + return await original_request(method, params) + + client._client.request = mock_request + + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + commands=[ + CommandDefinition( + name="deploy", + description="Deploy the app", + handler=lambda ctx: None, + ), + CommandDefinition( + name="rollback", + handler=lambda ctx: None, + ), + ], + ) + + payload = captured["session.create"] + assert payload["commands"] == [ + {"name": "deploy", "description": "Deploy the app"}, + {"name": "rollback", "description": None}, + ] + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_forwards_commands_in_session_resume_rpc(self): + """Verifies that commands are serialized in session.resume payload.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + + captured: dict = {} + + async def mock_request(method, params): + captured[method] = params + if method == "session.resume": + return {"sessionId": params["sessionId"]} + raise RuntimeError(f"Unexpected method: {method}") + + client._client.request = mock_request + + await client.resume_session( + session.session_id, + on_permission_request=PermissionHandler.approve_all, + commands=[ + CommandDefinition( + name="deploy", + description="Deploy", + handler=lambda ctx: None, + ), + ], + ) + + payload = captured["session.resume"] + assert payload["commands"] == [{"name": "deploy", "description": "Deploy"}] + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_routes_command_execute_event_to_correct_handler(self): + """Verifies the command dispatch works for command.execute events.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + handler_calls: list[CommandContext] = [] + + async def deploy_handler(ctx: CommandContext) -> None: + handler_calls.append(ctx) + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + commands=[ + CommandDefinition(name="deploy", handler=deploy_handler), + ], + ) + + # Mock the RPC so handlePendingCommand doesn't fail + rpc_calls: list[tuple] = [] + original_request = client._client.request + + async def mock_request(method, params): + if method == "session.commands.handlePendingCommand": + rpc_calls.append((method, params)) + return {"success": True} + return await original_request(method, params) + + client._client.request = mock_request + + # Simulate a command.execute broadcast event + from copilot.generated.session_events import ( + Data, + SessionEvent, + SessionEventType, + ) + + event = SessionEvent( + data=Data( + request_id="req-1", + command="/deploy production", + command_name="deploy", + args="production", + ), + id="evt-1", + timestamp="2025-01-01T00:00:00Z", + type=SessionEventType.COMMAND_EXECUTE, + ephemeral=True, + parent_id=None, + ) + session._dispatch_event(event) + + # Wait for async handler + await asyncio.sleep(0.2) + + assert len(handler_calls) == 1 + assert handler_calls[0].session_id == session.session_id + assert handler_calls[0].command == "/deploy production" + assert handler_calls[0].command_name == "deploy" + assert handler_calls[0].args == "production" + + # Verify handlePendingCommand was called + assert len(rpc_calls) >= 1 + assert rpc_calls[0][1]["requestId"] == "req-1" + # No error key means success + assert "error" not in rpc_calls[0][1] or rpc_calls[0][1].get("error") is None + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_sends_error_when_command_handler_throws(self): + """Verifies error is sent via RPC when a command handler raises.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + + def fail_handler(ctx: CommandContext) -> None: + raise RuntimeError("deploy failed") + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + commands=[ + CommandDefinition(name="fail", handler=fail_handler), + ], + ) + + rpc_calls: list[tuple] = [] + original_request = client._client.request + + async def mock_request(method, params): + if method == "session.commands.handlePendingCommand": + rpc_calls.append((method, params)) + return {"success": True} + return await original_request(method, params) + + client._client.request = mock_request + + from copilot.generated.session_events import ( + Data, + SessionEvent, + SessionEventType, + ) + + event = SessionEvent( + data=Data( + request_id="req-2", + command="/fail", + command_name="fail", + args="", + ), + id="evt-2", + timestamp="2025-01-01T00:00:00Z", + type=SessionEventType.COMMAND_EXECUTE, + ephemeral=True, + parent_id=None, + ) + session._dispatch_event(event) + + await asyncio.sleep(0.2) + + assert len(rpc_calls) >= 1 + assert rpc_calls[0][1]["requestId"] == "req-2" + assert "deploy failed" in rpc_calls[0][1]["error"] + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_sends_error_for_unknown_command(self): + """Verifies error is sent via RPC for an unrecognized command.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + commands=[ + CommandDefinition(name="deploy", handler=lambda ctx: None), + ], + ) + + rpc_calls: list[tuple] = [] + original_request = client._client.request + + async def mock_request(method, params): + if method == "session.commands.handlePendingCommand": + rpc_calls.append((method, params)) + return {"success": True} + return await original_request(method, params) + + client._client.request = mock_request + + from copilot.generated.session_events import ( + Data, + SessionEvent, + SessionEventType, + ) + + event = SessionEvent( + data=Data( + request_id="req-3", + command="/unknown", + command_name="unknown", + args="", + ), + id="evt-3", + timestamp="2025-01-01T00:00:00Z", + type=SessionEventType.COMMAND_EXECUTE, + ephemeral=True, + parent_id=None, + ) + session._dispatch_event(event) + + await asyncio.sleep(0.2) + + assert len(rpc_calls) >= 1 + assert rpc_calls[0][1]["requestId"] == "req-3" + assert "Unknown command" in rpc_calls[0][1]["error"] + finally: + await client.force_stop() + + +# ============================================================================ +# UI Elicitation (client → server) +# ============================================================================ + + +class TestUiElicitation: + @pytest.mark.asyncio + async def test_reads_capabilities_from_session_create_response(self): + """Verifies capabilities are parsed from session.create response.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + original_request = client._client.request + + async def mock_request(method, params): + if method == "session.create": + result = await original_request(method, params) + return {**result, "capabilities": {"ui": {"elicitation": True}}} + return await original_request(method, params) + + client._client.request = mock_request + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + assert session.capabilities == {"ui": {"elicitation": True}} + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_defaults_capabilities_when_not_injected(self): + """Verifies capabilities default to empty when server returns none.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + # CLI returns actual capabilities; in headless mode, elicitation is + # either False or absent. Just verify we don't crash. + ui_caps = session.capabilities.get("ui", {}) + assert ui_caps.get("elicitation") in (False, None, True) + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_elicitation_throws_when_capability_is_missing(self): + """Verifies that UI methods throw when elicitation is not supported.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + # Force capabilities to not support elicitation + session._set_capabilities({}) + + with pytest.raises(RuntimeError, match="not supported"): + await session.ui.elicitation( + { + "message": "Enter name", + "requestedSchema": { + "type": "object", + "properties": {"name": {"type": "string", "minLength": 1}}, + "required": ["name"], + }, + } + ) + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_confirm_throws_when_capability_is_missing(self): + """Verifies confirm throws when elicitation is not supported.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + session._set_capabilities({}) + + with pytest.raises(RuntimeError, match="not supported"): + await session.ui.confirm("Deploy?") + finally: + await client.force_stop() + + +# ============================================================================ +# onElicitationRequest (server → client callback) +# ============================================================================ + + +class TestOnElicitationRequest: + @pytest.mark.asyncio + async def test_sends_request_elicitation_flag_when_handler_provided(self): + """Verifies requestElicitation=true is sent when onElicitationRequest is provided.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + captured: dict = {} + original_request = client._client.request + + async def mock_request(method, params): + captured[method] = params + return await original_request(method, params) + + client._client.request = mock_request + + async def elicitation_handler( + request: ElicitationRequest, invocation: dict[str, str] + ) -> ElicitationResult: + return {"action": "accept", "content": {}} + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_elicitation_request=elicitation_handler, + ) + assert session is not None + + payload = captured["session.create"] + assert payload["requestElicitation"] is True + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_does_not_send_request_elicitation_when_no_handler(self): + """Verifies requestElicitation=false when no handler is provided.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + captured: dict = {} + original_request = client._client.request + + async def mock_request(method, params): + captured[method] = params + return await original_request(method, params) + + client._client.request = mock_request + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + assert session is not None + + payload = captured["session.create"] + assert payload["requestElicitation"] is False + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_sends_cancel_when_elicitation_handler_throws(self): + """Verifies auto-cancel when the elicitation handler raises.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + + async def bad_handler( + request: ElicitationRequest, invocation: dict[str, str] + ) -> ElicitationResult: + raise RuntimeError("handler exploded") + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_elicitation_request=bad_handler, + ) + + rpc_calls: list[tuple] = [] + original_request = client._client.request + + async def mock_request(method, params): + if method == "session.ui.handlePendingElicitation": + rpc_calls.append((method, params)) + return {"success": True} + return await original_request(method, params) + + client._client.request = mock_request + + # Call _handle_elicitation_request directly (as Node.js test does) + await session._handle_elicitation_request({"message": "Pick a color"}, "req-123") + + assert len(rpc_calls) >= 1 + cancel_call = next( + (call for call in rpc_calls if call[1].get("result", {}).get("action") == "cancel"), + None, + ) + assert cancel_call is not None + assert cancel_call[1]["requestId"] == "req-123" + assert cancel_call[1]["result"]["action"] == "cancel" + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_dispatches_elicitation_requested_event_to_handler(self): + """Verifies that an elicitation.requested event dispatches to the handler.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + handler_calls: list = [] + + async def elicitation_handler( + request: ElicitationRequest, invocation: dict[str, str] + ) -> ElicitationResult: + handler_calls.append(request) + return {"action": "accept", "content": {"color": "blue"}} + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_elicitation_request=elicitation_handler, + ) + + rpc_calls: list[tuple] = [] + original_request = client._client.request + + async def mock_request(method, params): + if method == "session.ui.handlePendingElicitation": + rpc_calls.append((method, params)) + return {"success": True} + return await original_request(method, params) + + client._client.request = mock_request + + from copilot.generated.session_events import ( + Data, + SessionEvent, + SessionEventType, + ) + + event = SessionEvent( + data=Data( + request_id="req-elicit-1", + message="Pick a color", + ), + id="evt-elicit-1", + timestamp="2025-01-01T00:00:00Z", + type=SessionEventType.ELICITATION_REQUESTED, + ephemeral=True, + parent_id=None, + ) + session._dispatch_event(event) + + await asyncio.sleep(0.2) + + assert len(handler_calls) == 1 + assert handler_calls[0]["message"] == "Pick a color" + + assert len(rpc_calls) >= 1 + assert rpc_calls[0][1]["requestId"] == "req-elicit-1" + assert rpc_calls[0][1]["result"]["action"] == "accept" + finally: + await client.force_stop() + + +# ============================================================================ +# Capabilities changed event +# ============================================================================ + + +class TestCapabilitiesChanged: + @pytest.mark.asyncio + async def test_capabilities_changed_event_updates_session(self): + """Verifies that a capabilities.changed event updates session capabilities.""" + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH)) + await client.start() + + try: + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + session._set_capabilities({}) + + from copilot.generated.session_events import ( + UI, + Data, + SessionEvent, + SessionEventType, + ) + + event = SessionEvent( + data=Data(ui=UI(elicitation=True)), + id="evt-cap-1", + timestamp="2025-01-01T00:00:00Z", + type=SessionEventType.CAPABILITIES_CHANGED, + ephemeral=True, + parent_id=None, + ) + session._dispatch_event(event) + + assert session.capabilities.get("ui", {}).get("elicitation") is True + finally: + await client.force_stop() diff --git a/test/snapshots/commands/forwards_commands_in_session_create.yaml b/test/snapshots/commands/forwards_commands_in_session_create.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/commands/forwards_commands_in_session_create.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/commands/forwards_commands_in_session_resume.yaml b/test/snapshots/commands/forwards_commands_in_session_resume.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/commands/forwards_commands_in_session_resume.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/commands/session_with_no_commands_creates_successfully.yaml b/test/snapshots/commands/session_with_no_commands_creates_successfully.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/commands/session_with_no_commands_creates_successfully.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/elicitation/defaults_capabilities_when_not_provided.yaml b/test/snapshots/elicitation/defaults_capabilities_when_not_provided.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/elicitation/defaults_capabilities_when_not_provided.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/elicitation/elicitation_throws_when_capability_is_missing.yaml b/test/snapshots/elicitation/elicitation_throws_when_capability_is_missing.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/elicitation/elicitation_throws_when_capability_is_missing.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/elicitation/sends_requestelicitation_when_handler_provided.yaml b/test/snapshots/elicitation/sends_requestelicitation_when_handler_provided.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/elicitation/sends_requestelicitation_when_handler_provided.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/elicitation/session_with_elicitation_handler_reports_elicitation_capability.yaml b/test/snapshots/elicitation/session_with_elicitation_handler_reports_elicitation_capability.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/session_with_elicitation_handler_reports_elicitation_capability.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/session_with_elicitationhandler_reports_elicitation_capability.yaml b/test/snapshots/elicitation/session_with_elicitationhandler_reports_elicitation_capability.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/elicitation/session_with_elicitationhandler_reports_elicitation_capability.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/elicitation/session_without_elicitation_handler_reports_no_capability.yaml b/test/snapshots/elicitation/session_without_elicitation_handler_reports_no_capability.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/elicitation/session_without_elicitation_handler_reports_no_capability.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/elicitation/session_without_elicitationhandler_creates_successfully.yaml b/test/snapshots/elicitation/session_without_elicitationhandler_creates_successfully.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/elicitation/session_without_elicitationhandler_creates_successfully.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/elicitation/session_without_elicitationhandler_reports_no_capability.yaml b/test/snapshots/elicitation/session_without_elicitationhandler_reports_no_capability.yaml new file mode 100644 index 000000000..e00c26279 --- /dev/null +++ b/test/snapshots/elicitation/session_without_elicitationhandler_reports_no_capability.yaml @@ -0,0 +1,4 @@ +models: + - claude-sonnet-4.5 +conversations: [] + diff --git a/test/snapshots/multi_client/capabilities_changed_fires_when_elicitation_provider_disconnects.yaml b/test/snapshots/multi_client/capabilities_changed_fires_when_elicitation_provider_disconnects.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/multi_client/capabilities_changed_fires_when_elicitation_provider_disconnects.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/multi_client/capabilities_changed_fires_when_second_client_joins_with_elicitation_handler.yaml b/test/snapshots/multi_client/capabilities_changed_fires_when_second_client_joins_with_elicitation_handler.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/multi_client/capabilities_changed_fires_when_second_client_joins_with_elicitation_handler.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: [] diff --git a/test/snapshots/multi_client/client_receives_commands_changed_when_another_client_joins_with_commands.yaml b/test/snapshots/multi_client/client_receives_commands_changed_when_another_client_joins_with_commands.yaml new file mode 100644 index 000000000..056351ddb --- /dev/null +++ b/test/snapshots/multi_client/client_receives_commands_changed_when_another_client_joins_with_commands.yaml @@ -0,0 +1,3 @@ +models: + - claude-sonnet-4.5 +conversations: []