From 37de12662bc253760b7dd59b77d9c575d575f832 Mon Sep 17 00:00:00 2001 From: Patrick Nikoletich Date: Sun, 8 Mar 2026 14:36:48 -0700 Subject: [PATCH] feat: add onListModels handler to CopilotClientOptions for BYOK mode Add an optional onListModels handler to CopilotClientOptions across all 4 SDKs (Node, Python, Go, .NET). When provided, client.listModels() calls the handler instead of sending the models.list RPC to the CLI server. This enables BYOK users to return their provider's available models in the standard ModelInfo format. - Handler completely replaces CLI RPC when set (no fallback) - Results cached identically to CLI path (same locking/thread-safety) - No connection required when handler is provided - Supports both sync and async handlers - 10 new unit tests across all SDKs - Updated BYOK docs with usage examples in all 4 languages Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- docs/auth/byok.md | 110 ++++++++++++++++++++++++++++++++++++ dotnet/src/Client.cs | 29 +++++++--- dotnet/src/Types.cs | 9 +++ dotnet/test/ClientTests.cs | 100 +++++++++++++++++++++++++++++++++ go/client.go | 51 +++++++++++------ go/client_test.go | 60 ++++++++++++++++++++ go/types.go | 10 +++- nodejs/src/client.ts | 35 ++++++++---- nodejs/src/types.ts | 8 +++ nodejs/test/client.test.ts | 89 ++++++++++++++++++++++++++++- python/copilot/client.py | 34 ++++++++---- python/copilot/types.py | 5 ++ python/test_client.py | 111 +++++++++++++++++++++++++++++++++++++ 13 files changed, 600 insertions(+), 51 deletions(-) diff --git a/docs/auth/byok.md b/docs/auth/byok.md index 49d2452d9..df334508d 100644 --- a/docs/auth/byok.md +++ b/docs/auth/byok.md @@ -306,6 +306,116 @@ provider: { > **Note:** The `bearerToken` option accepts a **static token string** only. The SDK does not refresh this token automatically. If your token expires, requests will fail and you'll need to create a new session with a fresh token. +## Custom Model Listing + +When using BYOK, the CLI server may not know which models your provider supports. You can supply a custom `onListModels` handler at the client level so that `client.listModels()` returns your provider's models in the standard `ModelInfo` format. This lets downstream consumers discover available models without querying the CLI. + +
+Node.js / TypeScript + +```typescript +import { CopilotClient } from "@github/copilot-sdk"; +import type { ModelInfo } from "@github/copilot-sdk"; + +const client = new CopilotClient({ + onListModels: () => [ + { + id: "my-custom-model", + name: "My Custom Model", + capabilities: { + supports: { vision: false, reasoningEffort: false }, + limits: { max_context_window_tokens: 128000 }, + }, + }, + ], +}); +``` + +
+ +
+Python + +```python +from copilot import CopilotClient +from copilot.types import ModelInfo, ModelCapabilities, ModelSupports, ModelLimits + +client = CopilotClient({ + "on_list_models": lambda: [ + ModelInfo( + id="my-custom-model", + name="My Custom Model", + capabilities=ModelCapabilities( + supports=ModelSupports(vision=False, reasoning_effort=False), + limits=ModelLimits(max_context_window_tokens=128000), + ), + ) + ], +}) +``` + +
+ +
+Go + +```go +package main + +import ( + "context" + copilot "github.com/github/copilot-sdk/go" +) + +func main() { + client := copilot.NewClient(&copilot.ClientOptions{ + OnListModels: func(ctx context.Context) ([]copilot.ModelInfo, error) { + return []copilot.ModelInfo{ + { + ID: "my-custom-model", + Name: "My Custom Model", + Capabilities: copilot.ModelCapabilities{ + Supports: copilot.ModelSupports{Vision: false, ReasoningEffort: false}, + Limits: copilot.ModelLimits{MaxContextWindowTokens: 128000}, + }, + }, + }, nil + }, + }) + _ = client +} +``` + +
+ +
+.NET + +```csharp +using GitHub.Copilot.SDK; + +var client = new CopilotClient(new CopilotClientOptions +{ + OnListModels = (ct) => Task.FromResult(new List + { + new() + { + Id = "my-custom-model", + Name = "My Custom Model", + Capabilities = new ModelCapabilities + { + Supports = new ModelSupports { Vision = false, ReasoningEffort = false }, + Limits = new ModelLimits { MaxContextWindowTokens = 128000 } + } + } + }) +}); +``` + +
+ +Results are cached after the first call, just like the default behavior. The handler completely replaces the CLI's `models.list` RPC — no fallback to the server occurs. + ## Limitations When using BYOK, be aware of these limitations: diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 91b6353ff..1b4da2ffb 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -70,6 +70,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private int? _negotiatedProtocolVersion; private List? _modelsCache; private readonly SemaphoreSlim _modelsCacheLock = new(1, 1); + private readonly Func>>? _onListModels; private readonly List> _lifecycleHandlers = []; private readonly Dictionary>> _typedLifecycleHandlers = []; private readonly object _lifecycleHandlersLock = new(); @@ -136,6 +137,7 @@ public CopilotClient(CopilotClientOptions? options = null) } _logger = _options.Logger ?? NullLogger.Instance; + _onListModels = _options.OnListModels; // Parse CliUrl if provided if (!string.IsNullOrEmpty(_options.CliUrl)) @@ -624,9 +626,6 @@ public async Task GetAuthStatusAsync(CancellationToken ca /// Thrown when the client is not connected or not authenticated. public async Task> ListModelsAsync(CancellationToken cancellationToken = default) { - var connection = await EnsureConnectedAsync(cancellationToken); - - // Use semaphore for async locking to prevent race condition with concurrent calls await _modelsCacheLock.WaitAsync(cancellationToken); try { @@ -636,14 +635,26 @@ public async Task> ListModelsAsync(CancellationToken cancellatio return [.. _modelsCache]; // Return a copy to prevent cache mutation } - // Cache miss - fetch from backend while holding lock - var response = await InvokeRpcAsync( - connection.Rpc, "models.list", [], cancellationToken); + List models; + if (_onListModels is not null) + { + // Use custom handler instead of CLI RPC + models = await _onListModels(cancellationToken); + } + else + { + var connection = await EnsureConnectedAsync(cancellationToken); + + // Cache miss - fetch from backend while holding lock + var response = await InvokeRpcAsync( + connection.Rpc, "models.list", [], cancellationToken); + models = response.Models; + } - // Update cache before releasing lock - _modelsCache = response.Models; + // Update cache before releasing lock (copy to prevent external mutation) + _modelsCache = [.. models]; - return [.. response.Models]; // Return a copy to prevent cache mutation + return [.. models]; // Return a copy to prevent cache mutation } finally { diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 52d870b80..a132e4818 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -63,6 +63,7 @@ protected CopilotClientOptions(CopilotClientOptions? other) Port = other.Port; UseLoggedInUser = other.UseLoggedInUser; UseStdio = other.UseStdio; + OnListModels = other.OnListModels; } /// @@ -136,6 +137,14 @@ public string? GithubToken /// public bool? UseLoggedInUser { get; set; } + /// + /// Custom handler for listing available models. + /// When provided, ListModelsAsync() calls this handler instead of + /// querying the CLI server. Useful in BYOK mode to return models + /// available from your custom provider. + /// + public Func>>? OnListModels { get; set; } + /// /// Creates a shallow clone of this instance. /// diff --git a/dotnet/test/ClientTests.cs b/dotnet/test/ClientTests.cs index 3c3f3bdaa..6c70ffaa3 100644 --- a/dotnet/test/ClientTests.cs +++ b/dotnet/test/ClientTests.cs @@ -274,4 +274,104 @@ public async Task Should_Throw_When_ResumeSession_Called_Without_PermissionHandl Assert.Contains("OnPermissionRequest", ex.Message); Assert.Contains("is required", ex.Message); } + + [Fact] + public async Task ListModels_WithCustomHandler_CallsHandler() + { + var customModels = new List + { + new() + { + Id = "my-custom-model", + Name = "My Custom Model", + Capabilities = new ModelCapabilities + { + Supports = new ModelSupports { Vision = false, ReasoningEffort = false }, + Limits = new ModelLimits { MaxContextWindowTokens = 128000 } + } + } + }; + + var callCount = 0; + await using var client = new CopilotClient(new CopilotClientOptions + { + OnListModels = (ct) => + { + callCount++; + return Task.FromResult(customModels); + } + }); + await client.StartAsync(); + + var models = await client.ListModelsAsync(); + Assert.Equal(1, callCount); + Assert.Single(models); + Assert.Equal("my-custom-model", models[0].Id); + } + + [Fact] + public async Task ListModels_WithCustomHandler_CachesResults() + { + var customModels = new List + { + new() + { + Id = "cached-model", + Name = "Cached Model", + Capabilities = new ModelCapabilities + { + Supports = new ModelSupports { Vision = false, ReasoningEffort = false }, + Limits = new ModelLimits { MaxContextWindowTokens = 128000 } + } + } + }; + + var callCount = 0; + await using var client = new CopilotClient(new CopilotClientOptions + { + OnListModels = (ct) => + { + callCount++; + return Task.FromResult(customModels); + } + }); + await client.StartAsync(); + + await client.ListModelsAsync(); + await client.ListModelsAsync(); + Assert.Equal(1, callCount); // Only called once due to caching + } + + [Fact] + public async Task ListModels_WithCustomHandler_WorksWithoutStart() + { + var customModels = new List + { + new() + { + Id = "no-start-model", + Name = "No Start Model", + Capabilities = new ModelCapabilities + { + Supports = new ModelSupports { Vision = false, ReasoningEffort = false }, + Limits = new ModelLimits { MaxContextWindowTokens = 128000 } + } + } + }; + + var callCount = 0; + await using var client = new CopilotClient(new CopilotClientOptions + { + OnListModels = (ct) => + { + callCount++; + return Task.FromResult(customModels); + } + }); + + var models = await client.ListModelsAsync(); + Assert.Equal(1, callCount); + Assert.Single(models); + Assert.Equal("no-start-model", models[0].Id); + } } diff --git a/go/client.go b/go/client.go index 3c1fb28cf..d440b49b4 100644 --- a/go/client.go +++ b/go/client.go @@ -92,6 +92,7 @@ type Client struct { processErrorPtr *error osProcess atomic.Pointer[os.Process] negotiatedProtocolVersion int + onListModels func(ctx context.Context) ([]ModelInfo, error) // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). @@ -188,6 +189,9 @@ func NewClient(options *ClientOptions) *Client { if options.UseLoggedInUser != nil { opts.UseLoggedInUser = options.UseLoggedInUser } + if options.OnListModels != nil { + client.onListModels = options.OnListModels + } } // Default Env to current environment if not set @@ -1035,40 +1039,51 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err // Results are cached after the first successful call to avoid rate limiting. // The cache is cleared when the client disconnects. func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { - if c.client == nil { - return nil, fmt.Errorf("client not connected") - } - // Use mutex for locking to prevent race condition with concurrent calls c.modelsCacheMux.Lock() defer c.modelsCacheMux.Unlock() // Check cache (already inside lock) if c.modelsCache != nil { - // Return a copy to prevent cache mutation result := make([]ModelInfo, len(c.modelsCache)) copy(result, c.modelsCache) return result, nil } - // Cache miss - fetch from backend while holding lock - result, err := c.client.Request("models.list", listModelsRequest{}) - if err != nil { - return nil, err - } + var models []ModelInfo + if c.onListModels != nil { + // Use custom handler instead of CLI RPC + var err error + models, err = c.onListModels(ctx) + if err != nil { + return nil, err + } + } else { + if c.client == nil { + return nil, fmt.Errorf("client not connected") + } + // Cache miss - fetch from backend while holding lock + result, err := c.client.Request("models.list", listModelsRequest{}) + if err != nil { + return nil, err + } - var response listModelsResponse - if err := json.Unmarshal(result, &response); err != nil { - return nil, fmt.Errorf("failed to unmarshal models response: %w", err) + var response listModelsResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal models response: %w", err) + } + models = response.Models } - // Update cache before releasing lock - c.modelsCache = response.Models + // Update cache before releasing lock (copy to prevent external mutation) + cache := make([]ModelInfo, len(models)) + copy(cache, models) + c.modelsCache = cache // Return a copy to prevent cache mutation - models := make([]ModelInfo, len(response.Models)) - copy(models, response.Models) - return models, nil + result := make([]ModelInfo, len(models)) + copy(result, models) + return result, nil } // minProtocolVersion is the minimum protocol version this SDK can communicate with. diff --git a/go/client_test.go b/go/client_test.go index 76efe98ba..601215cbe 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -1,6 +1,7 @@ package copilot import ( + "context" "encoding/json" "os" "path/filepath" @@ -548,6 +549,65 @@ func TestClient_ResumeSession_RequiresPermissionHandler(t *testing.T) { }) } +func TestListModelsWithCustomHandler(t *testing.T) { + customModels := []ModelInfo{ + { + ID: "my-custom-model", + Name: "My Custom Model", + Capabilities: ModelCapabilities{ + Supports: ModelSupports{Vision: false, ReasoningEffort: false}, + Limits: ModelLimits{MaxContextWindowTokens: 128000}, + }, + }, + } + + callCount := 0 + handler := func(ctx context.Context) ([]ModelInfo, error) { + callCount++ + return customModels, nil + } + + client := NewClient(&ClientOptions{OnListModels: handler}) + + models, err := client.ListModels(t.Context()) + if err != nil { + t.Fatalf("ListModels failed: %v", err) + } + if callCount != 1 { + t.Errorf("expected handler called once, got %d", callCount) + } + if len(models) != 1 || models[0].ID != "my-custom-model" { + t.Errorf("unexpected models: %+v", models) + } +} + +func TestListModelsHandlerCachesResults(t *testing.T) { + customModels := []ModelInfo{ + { + ID: "cached-model", + Name: "Cached Model", + Capabilities: ModelCapabilities{ + Supports: ModelSupports{Vision: false, ReasoningEffort: false}, + Limits: ModelLimits{MaxContextWindowTokens: 128000}, + }, + }, + } + + callCount := 0 + handler := func(ctx context.Context) ([]ModelInfo, error) { + callCount++ + return customModels, nil + } + + client := NewClient(&ClientOptions{OnListModels: handler}) + + _, _ = client.ListModels(t.Context()) + _, _ = client.ListModels(t.Context()) + if callCount != 1 { + t.Errorf("expected handler called once due to caching, got %d", callCount) + } +} + func TestClient_StartStopRace(t *testing.T) { cliPath := findCLIPathForTest() if cliPath == "" { diff --git a/go/types.go b/go/types.go index 7970b2fe0..eaee2fb11 100644 --- a/go/types.go +++ b/go/types.go @@ -1,6 +1,9 @@ package copilot -import "encoding/json" +import ( + "context" + "encoding/json" +) // ConnectionState represents the client connection state type ConnectionState string @@ -54,6 +57,11 @@ type ClientOptions struct { // Default: true (but defaults to false when GitHubToken is provided). // Use Bool(false) to explicitly disable. UseLoggedInUser *bool + // OnListModels is a custom handler for listing available models. + // When provided, client.ListModels() calls this handler instead of + // querying the CLI server. Useful in BYOK mode to return models + // available from your custom provider. + OnListModels func(ctx context.Context) ([]ModelInfo, error) } // Bool returns a pointer to the given bool value. diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 1108edaea..8cc79bf56 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -141,7 +141,7 @@ export class CopilotClient { private sessions: Map = new Map(); private stderrBuffer: string = ""; // Captures CLI stderr for error messages private options: Required< - Omit + Omit > & { cliUrl?: string; githubToken?: string; @@ -149,6 +149,7 @@ export class CopilotClient { }; private isExternalServer: boolean = false; private forceStopping: boolean = false; + private onListModels?: () => Promise | ModelInfo[]; private modelsCache: ModelInfo[] | null = null; private modelsCacheLock: Promise = Promise.resolve(); private sessionLifecycleHandlers: Set = new Set(); @@ -226,6 +227,8 @@ export class CopilotClient { this.isExternalServer = true; } + this.onListModels = options.onListModels; + this.options = { cliPath: options.cliPath || getBundledCliPath(), cliArgs: options.cliArgs ?? [], @@ -751,16 +754,15 @@ export class CopilotClient { /** * List available models with their metadata. * + * If an `onListModels` handler was provided in the client options, + * it is called instead of querying the CLI server. + * * Results are cached after the first successful call to avoid rate limiting. * The cache is cleared when the client disconnects. * - * @throws Error if not authenticated + * @throws Error if not connected (when no custom handler is set) */ async listModels(): Promise { - if (!this.connection) { - throw new Error("Client not connected"); - } - // Use promise-based locking to prevent race condition with concurrent calls await this.modelsCacheLock; @@ -775,13 +777,22 @@ export class CopilotClient { return [...this.modelsCache]; // Return a copy to prevent cache mutation } - // Cache miss - fetch from backend while holding lock - const result = await this.connection.sendRequest("models.list", {}); - const response = result as { models: ModelInfo[] }; - const models = response.models; + let models: ModelInfo[]; + if (this.onListModels) { + // Use custom handler instead of CLI RPC + models = await this.onListModels(); + } else { + if (!this.connection) { + throw new Error("Client not connected"); + } + // Cache miss - fetch from backend while holding lock + const result = await this.connection.sendRequest("models.list", {}); + const response = result as { models: ModelInfo[] }; + models = response.models; + } - // Update cache before releasing lock - this.modelsCache = models; + // Update cache before releasing lock (copy to prevent external mutation) + this.modelsCache = [...models]; return [...models]; // Return a copy to prevent cache mutation } finally { diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index acda50fef..69c29396a 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -96,6 +96,14 @@ export interface CopilotClientOptions { * @default true (but defaults to false when githubToken is provided) */ useLoggedInUser?: boolean; + + /** + * Custom handler for listing available models. + * When provided, client.listModels() calls this handler instead of + * querying the CLI server. Useful in BYOK mode to return models + * available from your custom provider. + */ + onListModels?: () => Promise | ModelInfo[]; } /** diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 22f969998..ef227b698 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, expect, it, onTestFinished, vi } from "vitest"; -import { approveAll, CopilotClient } from "../src/index.js"; +import { approveAll, CopilotClient, type ModelInfo } from "../src/index.js"; // This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.ts instead @@ -388,4 +388,91 @@ describe("CopilotClient", () => { spy.mockRestore(); }); }); + + describe("onListModels", () => { + it("calls onListModels handler instead of RPC when provided", async () => { + const customModels: ModelInfo[] = [ + { + id: "my-custom-model", + name: "My Custom Model", + capabilities: { + supports: { vision: false, reasoningEffort: false }, + limits: { max_context_window_tokens: 128000 }, + }, + }, + ]; + + const handler = vi.fn().mockReturnValue(customModels); + const client = new CopilotClient({ onListModels: handler }); + await client.start(); + onTestFinished(() => client.forceStop()); + + const models = await client.listModels(); + expect(handler).toHaveBeenCalledTimes(1); + expect(models).toEqual(customModels); + }); + + it("caches onListModels results on subsequent calls", async () => { + const customModels: ModelInfo[] = [ + { + id: "cached-model", + name: "Cached Model", + capabilities: { + supports: { vision: false, reasoningEffort: false }, + limits: { max_context_window_tokens: 128000 }, + }, + }, + ]; + + const handler = vi.fn().mockReturnValue(customModels); + const client = new CopilotClient({ onListModels: handler }); + await client.start(); + onTestFinished(() => client.forceStop()); + + await client.listModels(); + await client.listModels(); + expect(handler).toHaveBeenCalledTimes(1); // Only called once due to caching + }); + + it("supports async onListModels handler", async () => { + const customModels: ModelInfo[] = [ + { + id: "async-model", + name: "Async Model", + capabilities: { + supports: { vision: false, reasoningEffort: false }, + limits: { max_context_window_tokens: 128000 }, + }, + }, + ]; + + const handler = vi.fn().mockResolvedValue(customModels); + const client = new CopilotClient({ onListModels: handler }); + await client.start(); + onTestFinished(() => client.forceStop()); + + const models = await client.listModels(); + expect(models).toEqual(customModels); + }); + + it("does not require client.start when onListModels is provided", async () => { + const customModels: ModelInfo[] = [ + { + id: "no-start-model", + name: "No Start Model", + capabilities: { + supports: { vision: false, reasoningEffort: false }, + limits: { max_context_window_tokens: 128000 }, + }, + }, + ]; + + const handler = vi.fn().mockReturnValue(customModels); + const client = new CopilotClient({ onListModels: handler }); + + const models = await client.listModels(); + expect(handler).toHaveBeenCalledTimes(1); + expect(models).toEqual(customModels); + }); + }); }); diff --git a/python/copilot/client.py b/python/copilot/client.py index c29f35d12..ff587d997 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -200,6 +200,8 @@ def __init__(self, options: CopilotClientOptions | None = None): if github_token: self.options["github_token"] = github_token + self._on_list_models = opts.get("on_list_models") + self._process: subprocess.Popen | None = None self._client: JsonRpcClient | None = None self._state: ConnectionState = "disconnected" @@ -897,11 +899,15 @@ async def list_models(self) -> list["ModelInfo"]: Results are cached after the first successful call to avoid rate limiting. The cache is cleared when the client disconnects. + If a custom ``on_list_models`` handler was provided in the client options, + it is called instead of querying the CLI server. The handler may be sync + or async. + Returns: A list of ModelInfo objects with model details. Raises: - RuntimeError: If the client is not connected. + RuntimeError: If the client is not connected (when no custom handler is set). Exception: If not authenticated. Example: @@ -909,22 +915,30 @@ async def list_models(self) -> list["ModelInfo"]: >>> for model in models: ... print(f"{model.id}: {model.name}") """ - if not self._client: - raise RuntimeError("Client not connected") - # Use asyncio lock to prevent race condition with concurrent calls async with self._models_cache_lock: # Check cache (already inside lock) if self._models_cache is not None: return list(self._models_cache) # Return a copy to prevent cache mutation - # Cache miss - fetch from backend while holding lock - response = await self._client.request("models.list", {}) - models_data = response.get("models", []) - models = [ModelInfo.from_dict(model) for model in models_data] + if self._on_list_models: + # Use custom handler instead of CLI RPC + result = self._on_list_models() + if inspect.isawaitable(result): + models = await result + else: + models = result + else: + if not self._client: + raise RuntimeError("Client not connected") + + # Cache miss - fetch from backend while holding lock + response = await self._client.request("models.list", {}) + models_data = response.get("models", []) + models = [ModelInfo.from_dict(model) for model in models_data] - # Update cache before releasing lock - self._models_cache = models + # Update cache before releasing lock (copy to prevent external mutation) + self._models_cache = list(models) return list(models) # Return a copy to prevent cache mutation diff --git a/python/copilot/types.py b/python/copilot/types.py index f094666ce..5f4b7e20d 100644 --- a/python/copilot/types.py +++ b/python/copilot/types.py @@ -98,6 +98,11 @@ class CopilotClientOptions(TypedDict, total=False): # When False, only explicit tokens (github_token or environment variables) are used. # Default: True (but defaults to False when github_token is provided) use_logged_in_user: bool + # Custom handler for listing available models. + # When provided, client.list_models() calls this handler instead of + # querying the CLI server. Useful in BYOK mode to return models + # available from your custom provider. + on_list_models: Callable[[], list[ModelInfo] | Awaitable[list[ModelInfo]]] ToolResultType = Literal["success", "failure", "rejected", "denied"] diff --git a/python/test_client.py b/python/test_client.py index ef068b7a1..4a06966d4 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -7,6 +7,7 @@ import pytest from copilot import CopilotClient, PermissionHandler, define_tool +from copilot.types import ModelCapabilities, ModelInfo, ModelLimits, ModelSupports from e2e.testharness import CLI_PATH @@ -214,6 +215,116 @@ def grep(params) -> str: await client.force_stop() +class TestOnListModels: + @pytest.mark.asyncio + async def test_list_models_with_custom_handler(self): + """Test that on_list_models handler is called instead of RPC""" + custom_models = [ + ModelInfo( + id="my-custom-model", + name="My Custom Model", + capabilities=ModelCapabilities( + supports=ModelSupports(vision=False, reasoning_effort=False), + limits=ModelLimits(max_context_window_tokens=128000), + ), + ) + ] + + handler_calls = [] + + def handler(): + handler_calls.append(1) + return custom_models + + client = CopilotClient({"cli_path": CLI_PATH, "on_list_models": handler}) + await client.start() + try: + models = await client.list_models() + assert len(handler_calls) == 1 + assert models == custom_models + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_list_models_handler_caches_results(self): + """Test that on_list_models results are cached""" + custom_models = [ + ModelInfo( + id="cached-model", + name="Cached Model", + capabilities=ModelCapabilities( + supports=ModelSupports(vision=False, reasoning_effort=False), + limits=ModelLimits(max_context_window_tokens=128000), + ), + ) + ] + + handler_calls = [] + + def handler(): + handler_calls.append(1) + return custom_models + + client = CopilotClient({"cli_path": CLI_PATH, "on_list_models": handler}) + await client.start() + try: + await client.list_models() + await client.list_models() + assert len(handler_calls) == 1 # Only called once due to caching + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_list_models_async_handler(self): + """Test that async on_list_models handler works""" + custom_models = [ + ModelInfo( + id="async-model", + name="Async Model", + capabilities=ModelCapabilities( + supports=ModelSupports(vision=False, reasoning_effort=False), + limits=ModelLimits(max_context_window_tokens=128000), + ), + ) + ] + + async def handler(): + return custom_models + + client = CopilotClient({"cli_path": CLI_PATH, "on_list_models": handler}) + await client.start() + try: + models = await client.list_models() + assert models == custom_models + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_list_models_handler_without_start(self): + """Test that on_list_models works without starting the CLI connection""" + custom_models = [ + ModelInfo( + id="no-start-model", + name="No Start Model", + capabilities=ModelCapabilities( + supports=ModelSupports(vision=False, reasoning_effort=False), + limits=ModelLimits(max_context_window_tokens=128000), + ), + ) + ] + + handler_calls = [] + + def handler(): + handler_calls.append(1) + return custom_models + + client = CopilotClient({"cli_path": CLI_PATH, "on_list_models": handler}) + models = await client.list_models() + assert len(handler_calls) == 1 + assert models == custom_models + + class TestSessionConfigForwarding: @pytest.mark.asyncio async def test_create_session_forwards_client_name(self):