diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 363aaf1d..670352d6 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -4,14 +4,12 @@ import ( "bufio" "bytes" "context" - _ "embed" "encoding/json" "fmt" "io" "net" "net/http" "net/http/httptest" - "slices" "strings" "sync" "sync/atomic" @@ -26,6 +24,8 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" @@ -44,35 +44,7 @@ import ( "golang.org/x/tools/txtar" ) -var ( - //go:embed fixtures/anthropic/simple.txtar - antSimple []byte - //go:embed fixtures/anthropic/single_builtin_tool.txtar - antSingleBuiltinTool []byte - //go:embed fixtures/anthropic/single_injected_tool.txtar - antSingleInjectedTool []byte - //go:embed fixtures/anthropic/fallthrough.txtar - antFallthrough []byte - //go:embed fixtures/anthropic/stream_error.txtar - antMidStreamErr []byte - //go:embed fixtures/anthropic/non_stream_error.txtar - antNonStreamErr []byte - - //go:embed fixtures/openai/chatcompletions/simple.txtar - oaiSimple []byte - //go:embed fixtures/openai/chatcompletions/single_builtin_tool.txtar - oaiSingleBuiltinTool []byte - //go:embed fixtures/openai/chatcompletions/single_injected_tool.txtar - oaiSingleInjectedTool []byte - //go:embed fixtures/openai/chatcompletions/fallthrough.txtar - oaiFallthrough []byte - //go:embed fixtures/openai/chatcompletions/stream_error.txtar - oaiMidStreamErr []byte - //go:embed fixtures/openai/chatcompletions/non_stream_error.txtar - oaiNonStreamErr []byte - - testTracer = otel.Tracer("forTesting") -) +var testTracer = otel.Tracer("forTesting") const ( fixtureRequest = "request" @@ -117,7 +89,7 @@ func TestAnthropicMessages(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleBuiltinTool) + arc := txtar.Parse(fixtures.AntSingleBuiltinTool) t.Logf("%s: %s", t.Name(), arc.Comment) files := filesMap(arc) @@ -138,7 +110,7 @@ func TestAnthropicMessages(t *testing.T) { srv := newMockServer(ctx, t, files, nil) t.Cleanup(srv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil)} @@ -194,7 +166,7 @@ func TestAnthropicMessages(t *testing.T) { require.Len(t, promptUsages, 1) assert.Equal(t, "read the foo file", promptUsages[0].Prompt) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.VerifyAllInterceptionsEnded(t) }) } }) @@ -206,7 +178,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run("invalid config", func(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleBuiltinTool) + arc := txtar.Parse(fixtures.AntSingleBuiltinTool) files := filesMap(arc) reqBody := files[fixtureRequest] @@ -222,7 +194,7 @@ func TestAWSBedrockIntegration(t *testing.T) { SmallFastModel: "test-haiku", } - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{ provider.NewAnthropic(anthropicCfg("http://unused", apiKey), bedrockCfg), @@ -253,7 +225,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleBuiltinTool) + arc := txtar.Parse(fixtures.AntSingleBuiltinTool) t.Logf("%s: %s", t.Name(), arc.Comment) files := filesMap(arc) @@ -319,7 +291,7 @@ func TestAWSBedrockIntegration(t *testing.T) { EndpointOverride: srv.URL, } - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( @@ -356,7 +328,7 @@ func TestAWSBedrockIntegration(t *testing.T) { interceptions := recorderClient.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.VerifyAllInterceptionsEnded(t) }) } }) @@ -388,7 +360,7 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiSingleBuiltinTool) + arc := txtar.Parse(fixtures.OaiChatSingleBuiltinTool) t.Logf("%s: %s", t.Name(), arc.Comment) files := filesMap(arc) @@ -409,7 +381,7 @@ func TestOpenAIChatCompletions(t *testing.T) { srv := newMockServer(ctx, t, files, nil) t.Cleanup(srv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(srv.URL, apiKey))} @@ -461,7 +433,7 @@ func TestOpenAIChatCompletions(t *testing.T) { require.Len(t, promptUsages, 1) assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.VerifyAllInterceptionsEnded(t) }) } }) @@ -480,7 +452,7 @@ func TestSimple(t *testing.T) { }{ { name: config.ProviderAnthropic, - fixture: antSimple, + fixture: fixtures.AntSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} @@ -519,7 +491,7 @@ func TestSimple(t *testing.T) { }, { name: config.ProviderOpenAI, - fixture: oaiSimple, + fixture: fixtures.OaiChatSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} @@ -588,7 +560,7 @@ func TestSimple(t *testing.T) { srv := newMockServer(ctx, t, files, nil) t.Cleanup(srv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} b, err := tc.configureFunc(srv.URL, recorderClient) require.NoError(t, err) @@ -631,7 +603,7 @@ func TestSimple(t *testing.T) { require.GreaterOrEqual(t, len(tokenUsages), 1) require.Equal(t, tokenUsages[0].MsgID, tc.expectedMsgID) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.VerifyAllInterceptionsEnded(t) }) } }) @@ -653,7 +625,7 @@ func TestFallthrough(t *testing.T) { }{ { name: config.ProviderAnthropic, - fixture: antFallthrough, + fixture: fixtures.AntFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := provider.NewAnthropic(anthropicCfg(addr, apiKey), nil) @@ -664,7 +636,7 @@ func TestFallthrough(t *testing.T) { }, { name: config.ProviderOpenAI, - fixture: oaiFallthrough, + fixture: fixtures.OaiChatFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) provider := provider.NewOpenAI(openaiCfg(addr, apiKey)) @@ -701,7 +673,7 @@ func TestFallthrough(t *testing.T) { })) t.Cleanup(upstream.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} provider, bridge := tc.configureFunc(upstream.URL, recorderClient) @@ -780,7 +752,7 @@ func TestAnthropicInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) + recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() @@ -870,7 +842,7 @@ func TestOpenAIInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) + recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) // Ensure expected tool was invoked with expected input. toolUsages := recorderClient.RecordedToolUsages() @@ -963,7 +935,7 @@ func TestOpenAIInjectedTools(t *testing.T) { // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. // Kinda fugly right now, we can refactor this later. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { t.Helper() arc := txtar.Parse(fixture) @@ -1006,7 +978,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu }) t.Cleanup(mockSrv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} // Setup MCP mcpProxiers. mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer) @@ -1056,7 +1028,7 @@ func TestErrorHandling(t *testing.T) { }{ { name: config.ProviderAnthropic, - fixture: antNonStreamErr, + fixture: fixtures.AntNonStreamError, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) @@ -1074,7 +1046,7 @@ func TestErrorHandling(t *testing.T) { }, { name: config.ProviderOpenAI, - fixture: oaiNonStreamErr, + fixture: fixtures.OaiChatNonStreamError, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) @@ -1126,7 +1098,7 @@ func TestErrorHandling(t *testing.T) { mockSrv := newMockHTTPReflector(ctx, t, mockResp) t.Cleanup(mockSrv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) @@ -1145,7 +1117,7 @@ func TestErrorHandling(t *testing.T) { require.NoError(t, err) tc.responseHandlerFn(resp) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.VerifyAllInterceptionsEnded(t) }) } }) @@ -1163,7 +1135,7 @@ func TestErrorHandling(t *testing.T) { }{ { name: config.ProviderAnthropic, - fixture: antMidStreamErr, + fixture: fixtures.AntMidStreamError, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) @@ -1182,7 +1154,7 @@ func TestErrorHandling(t *testing.T) { }, { name: config.ProviderOpenAI, - fixture: oaiMidStreamErr, + fixture: fixtures.OaiChatMidStreamError, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) @@ -1228,7 +1200,7 @@ func TestErrorHandling(t *testing.T) { mockSrv.statusCode = http.StatusInternalServerError t.Cleanup(mockSrv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) @@ -1248,7 +1220,7 @@ func TestErrorHandling(t *testing.T) { bridgeSrv.Close() tc.responseHandlerFn(resp) - recorderClient.verifyAllInterceptionsEnded(t) + recorderClient.VerifyAllInterceptionsEnded(t) }) } }) @@ -1271,7 +1243,7 @@ func TestStableRequestEncoding(t *testing.T) { }{ { name: config.ProviderAnthropic, - fixture: antSimple, + fixture: fixtures.AntSimple, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} @@ -1280,7 +1252,7 @@ func TestStableRequestEncoding(t *testing.T) { }, { name: config.ProviderOpenAI, - fixture: oaiSimple, + fixture: fixtures.OaiChatSimple, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} @@ -1344,7 +1316,7 @@ func TestStableRequestEncoding(t *testing.T) { mockSrv.Start() t.Cleanup(mockSrv.Close) - recorder := &mockRecorderClient{} + recorder := &testutil.MockRecorder{} bridge, err := tc.configureFunc(mockSrv.URL, recorder, mcpMgr) require.NoError(t, err) @@ -1462,7 +1434,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { } require.NoError(t, mcpMgr.Init(ctx)) - arc := txtar.Parse(antSimple) + arc := txtar.Parse(fixtures.AntSimple) files := filesMap(arc) require.Contains(t, files, fixtureRequest) require.Contains(t, files, fixtureNonStreamingResponse) @@ -1498,7 +1470,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { mockSrv.Start() t.Cleanup(mockSrv.Close) - recorder := &mockRecorderClient{} + recorder := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(mockSrv.URL, apiKey), nil)} bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer) @@ -1561,7 +1533,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { }{ { name: config.ProviderAnthropic, - fixture: antSimple, + fixture: fixtures.AntSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(addr, apiKey), nil)} @@ -1575,7 +1547,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { }, { name: config.ProviderOpenAI, - fixture: oaiSimple, + fixture: fixtures.OaiChatSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(addr, apiKey))} @@ -1620,7 +1592,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Setenv(key, val) } - recorderClient := &mockRecorderClient{} + recorderClient := &testutil.MockRecorder{} b, err := tc.configureFunc(srv.URL, recorderClient) require.NoError(t, err) @@ -1816,105 +1788,6 @@ func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, resp return ms } -var _ aibridge.Recorder = &mockRecorderClient{} - -type mockRecorderClient struct { - mu sync.Mutex - - interceptions []*recorder.InterceptionRecord - tokenUsages []*recorder.TokenUsageRecord - userPrompts []*recorder.PromptUsageRecord - toolUsages []*recorder.ToolUsageRecord - interceptionsEnd map[string]time.Time -} - -func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.interceptions = append(m.interceptions, req) - return nil -} - -func (m *mockRecorderClient) RecordInterceptionEnded(ctx context.Context, req *recorder.InterceptionRecordEnded) error { - m.mu.Lock() - defer m.mu.Unlock() - if m.interceptionsEnd == nil { - m.interceptionsEnd = make(map[string]time.Time) - } - if !slices.ContainsFunc(m.interceptions, func(intc *recorder.InterceptionRecord) bool { return intc.ID == req.ID }) { - return fmt.Errorf("id not found") - } - m.interceptionsEnd[req.ID] = req.EndedAt - return nil -} - -func (m *mockRecorderClient) RecordPromptUsage(ctx context.Context, req *recorder.PromptUsageRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.userPrompts = append(m.userPrompts, req) - return nil -} - -func (m *mockRecorderClient) RecordTokenUsage(ctx context.Context, req *recorder.TokenUsageRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.tokenUsages = append(m.tokenUsages, req) - return nil -} - -func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *recorder.ToolUsageRecord) error { - m.mu.Lock() - defer m.mu.Unlock() - m.toolUsages = append(m.toolUsages, req) - return nil -} - -// RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner. -// Note: This is a shallow clone - the slice is copied but the pointers reference the -// same underlying records. This is sufficient for our test assertions which only read -// the data and don't modify the records. -func (m *mockRecorderClient) RecordedTokenUsages() []*recorder.TokenUsageRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.tokenUsages) -} - -// RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner. -// Note: This is a shallow clone (see RecordedTokenUsages for details). -func (m *mockRecorderClient) RecordedPromptUsages() []*recorder.PromptUsageRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.userPrompts) -} - -// RecordedToolUsages returns a copy of recorded tool usages in a thread-safe manner. -// Note: This is a shallow clone (see RecordedTokenUsages for details). -func (m *mockRecorderClient) RecordedToolUsages() []*recorder.ToolUsageRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.toolUsages) -} - -// RecordedInterceptions returns a copy of recorded interceptions in a thread-safe manner. -// Note: This is a shallow clone (see RecordedTokenUsages for details). -func (m *mockRecorderClient) RecordedInterceptions() []*recorder.InterceptionRecord { - m.mu.Lock() - defer m.mu.Unlock() - return slices.Clone(m.interceptions) -} - -// verify all recorded interceptions has been marked as completed -func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { - t.Helper() - - m.mu.Lock() - defer m.mu.Unlock() - require.Equalf(t, len(m.interceptions), len(m.interceptionsEnd), "got %v interception ended calls, want: %v", len(m.interceptionsEnd), len(m.interceptions)) - for _, intc := range m.interceptions { - require.Containsf(t, m.interceptionsEnd, intc.ID, "interception with id: %v has not been ended", intc.ID) - } -} - const mockToolName = "coder_list_workspaces" // callAccumulator tracks all tool invocations by name and each instance's arguments. diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go new file mode 100644 index 00000000..7272bb73 --- /dev/null +++ b/fixtures/fixtures.go @@ -0,0 +1,85 @@ +package fixtures + +import ( + _ "embed" +) + +var ( + //go:embed anthropic/simple.txtar + AntSimple []byte + + //go:embed anthropic/single_builtin_tool.txtar + AntSingleBuiltinTool []byte + + //go:embed anthropic/single_injected_tool.txtar + AntSingleInjectedTool []byte + + //go:embed anthropic/fallthrough.txtar + AntFallthrough []byte + + //go:embed anthropic/stream_error.txtar + AntMidStreamError []byte + + //go:embed anthropic/non_stream_error.txtar + AntNonStreamError []byte +) + +var ( + //go:embed openai/chatcompletions/simple.txtar + OaiChatSimple []byte + + //go:embed openai/chatcompletions/single_builtin_tool.txtar + OaiChatSingleBuiltinTool []byte + + //go:embed openai/chatcompletions/single_injected_tool.txtar + OaiChatSingleInjectedTool []byte + + //go:embed openai/chatcompletions/fallthrough.txtar + OaiChatFallthrough []byte + + //go:embed openai/chatcompletions/stream_error.txtar + OaiChatMidStreamError []byte + + //go:embed openai/chatcompletions/non_stream_error.txtar + OaiChatNonStreamError []byte +) + +var ( + //go:embed openai/responses/blocking/simple.txtar + OaiResponsesBlockingSimple []byte + + //go:embed openai/responses/blocking/builtin_tool.txtar + OaiResponsesBlockingBuiltinTool []byte + + //go:embed openai/responses/blocking/conversation.txtar + OaiResponsesBlockingConversation []byte + + //go:embed openai/responses/blocking/prev_response_id.txtar + OaiResponsesBlockingPrevResponseID []byte + + //go:embed openai/responses/blocking/wrong_response_format.txtar + OaiResponsesBlockingWrongResponseFormat []byte +) + +var ( + //go:embed openai/responses/streaming/simple.txtar + OaiResponsesStreamingSimple []byte + + //go:embed openai/responses/streaming/builtin_tool.txtar + OaiResponsesStreamingBuiltinTool []byte + + //go:embed openai/responses/streaming/conversation.txtar + OaiResponsesStreamingConversation []byte + + //go:embed openai/responses/streaming/prev_response_id.txtar + OaiResponsesStreamingPrevResponseID []byte + + //go:embed openai/responses/streaming/stream_error.txtar + OaiResponsesStreamingStreamError []byte + + //go:embed openai/responses/streaming/stream_failure.txtar + OaiResponsesStreamingStreamFailure []byte + + //go:embed openai/responses/streaming/wrong_response_format.txtar + OaiResponsesStreamingWrongResponseFormat []byte +) diff --git a/internal/testutil/mock_recorder.go b/internal/testutil/mock_recorder.go new file mode 100644 index 00000000..d4c9c8e0 --- /dev/null +++ b/internal/testutil/mock_recorder.go @@ -0,0 +1,120 @@ +package testutil + +import ( + "context" + "fmt" + "slices" + "sync" + "testing" + "time" + + "github.com/coder/aibridge/recorder" + "github.com/stretchr/testify/require" +) + +// MockRecorder is a test implementation of aibridge.Recorder that +// captures all recording calls for test assertions. +type MockRecorder struct { + mu sync.Mutex + + interceptions []*recorder.InterceptionRecord + tokenUsages []*recorder.TokenUsageRecord + userPrompts []*recorder.PromptUsageRecord + toolUsages []*recorder.ToolUsageRecord + interceptionsEnd map[string]time.Time +} + +func (m *MockRecorder) RecordInterception(ctx context.Context, req *recorder.InterceptionRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.interceptions = append(m.interceptions, req) + return nil +} + +func (m *MockRecorder) RecordInterceptionEnded(ctx context.Context, req *recorder.InterceptionRecordEnded) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.interceptionsEnd == nil { + m.interceptionsEnd = make(map[string]time.Time) + } + if !slices.ContainsFunc(m.interceptions, func(intc *recorder.InterceptionRecord) bool { return intc.ID == req.ID }) { + return fmt.Errorf("id not found") + } + m.interceptionsEnd[req.ID] = req.EndedAt + return nil +} + +func (m *MockRecorder) RecordPromptUsage(ctx context.Context, req *recorder.PromptUsageRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.userPrompts = append(m.userPrompts, req) + return nil +} + +func (m *MockRecorder) RecordTokenUsage(ctx context.Context, req *recorder.TokenUsageRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.tokenUsages = append(m.tokenUsages, req) + return nil +} + +func (m *MockRecorder) RecordToolUsage(ctx context.Context, req *recorder.ToolUsageRecord) error { + m.mu.Lock() + defer m.mu.Unlock() + m.toolUsages = append(m.toolUsages, req) + return nil +} + +// RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner. +// Note: This is a shallow clone - the slice is copied but the pointers reference the +// same underlying records. This is sufficient for our test assertions which only read +// the data and don't modify the records. +func (m *MockRecorder) RecordedTokenUsages() []*recorder.TokenUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.tokenUsages) +} + +// RecordedPromptUsages returns a copy of recorded prompt usages in a thread-safe manner. +// Note: This is a shallow clone (see RecordedTokenUsages for details). +func (m *MockRecorder) RecordedPromptUsages() []*recorder.PromptUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.userPrompts) +} + +// RecordedToolUsages returns a copy of recorded tool usages in a thread-safe manner. +// Note: This is a shallow clone (see RecordedTokenUsages for details). +func (m *MockRecorder) RecordedToolUsages() []*recorder.ToolUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.toolUsages) +} + +// RecordedInterceptions returns a copy of recorded interceptions in a thread-safe manner. +// Note: This is a shallow clone (see RecordedTokenUsages for details). +func (m *MockRecorder) RecordedInterceptions() []*recorder.InterceptionRecord { + m.mu.Lock() + defer m.mu.Unlock() + return slices.Clone(m.interceptions) +} + +// ToolUsages returns the raw toolUsages slice for direct field access in tests. +// Use RecordedToolUsages() for thread-safe access when assertions don't need direct field access. +func (m *MockRecorder) ToolUsages() []*recorder.ToolUsageRecord { + m.mu.Lock() + defer m.mu.Unlock() + return m.toolUsages +} + +// VerifyAllInterceptionsEnded verifies all recorded interceptions have been marked as completed. +func (m *MockRecorder) VerifyAllInterceptionsEnded(t *testing.T) { + t.Helper() + + m.mu.Lock() + defer m.mu.Unlock() + require.Equalf(t, len(m.interceptions), len(m.interceptionsEnd), "got %v interception ended calls, want: %v", len(m.interceptionsEnd), len(m.interceptions)) + for _, intc := range m.interceptions { + require.Containsf(t, m.interceptionsEnd, intc.ID, "interception with id: %v has not been ended", intc.ID) + } +} diff --git a/metrics_integration_test.go b/metrics_integration_test.go index ba8742fb..d863927c 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -14,6 +14,8 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/provider" @@ -32,11 +34,11 @@ func TestMetrics_Interception(t *testing.T) { expectedStatus string }{ { - fixture: antSimple, + fixture: fixtures.AntSimple, expectedStatus: metrics.InterceptionCountStatusCompleted, }, { - fixture: antNonStreamErr, + fixture: fixtures.AntNonStreamError, expectedStatus: metrics.InterceptionCountStatusFailed, }, } @@ -71,7 +73,7 @@ func TestMetrics_Interception(t *testing.T) { func TestMetrics_InterceptionsInflight(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSimple) + arc := txtar.Parse(fixtures.AntSimple) files := filesMap(arc) ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -134,7 +136,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { func TestMetrics_PassthroughCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiFallthrough) + arc := txtar.Parse(fixtures.OaiChatFallthrough) files := filesMap(arc) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -164,7 +166,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { func TestMetrics_PromptCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiSimple) + arc := txtar.Parse(fixtures.OaiChatSimple) files := filesMap(arc) ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -192,7 +194,7 @@ func TestMetrics_PromptCount(t *testing.T) { func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiSingleBuiltinTool) + arc := txtar.Parse(fixtures.OaiChatSingleBuiltinTool) files := filesMap(arc) ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -220,7 +222,7 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { func TestMetrics_InjectedToolUseCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(antSingleInjectedTool) + arc := txtar.Parse(fixtures.AntSingleInjectedTool) files := filesMap(arc) ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) @@ -235,7 +237,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { }) t.Cleanup(mockAPI.Close) - recorder := &mockRecorderClient{} + recorder := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) provider := provider.NewAnthropic(anthropicCfg(mockAPI.URL, apiKey), nil) @@ -267,21 +269,21 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { return mockAPI.callCount.Load() == 2 }, time.Second*10, time.Millisecond*50) - require.Len(t, recorder.toolUsages, 1) - require.True(t, recorder.toolUsages[0].Injected) - require.NotNil(t, recorder.toolUsages[0].ServerURL) - actualServerURL := *recorder.toolUsages[0].ServerURL + require.Len(t, recorder.ToolUsages(), 1) + require.True(t, recorder.ToolUsages()[0].Injected) + require.NotNil(t, recorder.ToolUsages()[0].ServerURL) + actualServerURL := *recorder.ToolUsages()[0].ServerURL count := promtest.ToFloat64(metrics.InjectedToolUseCount.WithLabelValues( config.ProviderAnthropic, "claude-sonnet-4-20250514", actualServerURL, mockToolName)) require.Equal(t, 1.0, count) } -func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *metrics.Metrics, tracer trace.Tracer) (*httptest.Server, *mockRecorderClient) { +func newTestSrv(t *testing.T, ctx context.Context, provider aibridge.Provider, metrics *metrics.Metrics, tracer trace.Tracer) (*httptest.Server, *testutil.MockRecorder) { t.Helper() logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - mockRecorder := &mockRecorderClient{} + mockRecorder := &testutil.MockRecorder{} clientFn := func() (aibridge.Recorder, error) { return mockRecorder, nil } diff --git a/responses_integration_test.go b/responses_integration_test.go index bf9b6f81..ae6bdf94 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -3,7 +3,6 @@ package aibridge_test import ( "bytes" "context" - _ "embed" "encoding/json" "io" "net" @@ -13,50 +12,13 @@ import ( "testing" "time" + "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/provider" "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/require" "golang.org/x/tools/txtar" ) -var ( - //go:embed fixtures/openai/responses/blocking/simple.txtar - fixtResponsesBlockingSimple []byte - - //go:embed fixtures/openai/responses/blocking/builtin_tool.txtar - fixtResponsesBlockingBuiltinTool []byte - - //go:embed fixtures/openai/responses/blocking/conversation.txtar - fixtResponsesBlockingConversation []byte - - //go:embed fixtures/openai/responses/blocking/prev_response_id.txtar - fixtResponsesBlockingPrevResponseID []byte - - //go:embed fixtures/openai/responses/blocking/wrong_response_format.txtar - fixtResponsesBlockingWrongResponseFormat []byte - - //go:embed fixtures/openai/responses/streaming/simple.txtar - fixtResponsesStreamingSimple []byte - - //go:embed fixtures/openai/responses/streaming/builtin_tool.txtar - fixtResponsesStreamingBuiltinTool []byte - - //go:embed fixtures/openai/responses/streaming/conversation.txtar - fixtResponsesStreamingConversation []byte - - //go:embed fixtures/openai/responses/streaming/prev_response_id.txtar - fixtResponsesStreamingPrevResponseID []byte - - //go:embed fixtures/openai/responses/streaming/stream_error.txtar - fixtResponsesStreamingStreamError []byte - - //go:embed fixtures/openai/responses/streaming/stream_failure.txtar - fixtResponsesStreamingStreamFailure []byte - - //go:embed fixtures/openai/responses/streaming/wrong_response_format.txtar - fixtResponsesStreamingWrongResponseFormat []byte -) - type keyVal struct { key string val any @@ -72,60 +34,59 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { }{ { name: "blocking_simple", - fixture: fixtResponsesBlockingSimple, + fixture: fixtures.OaiResponsesBlockingSimple, }, { name: "blocking_builtin_tool", - fixture: fixtResponsesBlockingBuiltinTool, + fixture: fixtures.OaiResponsesBlockingBuiltinTool, }, { name: "blocking_conversation", - fixture: fixtResponsesBlockingConversation, + fixture: fixtures.OaiResponsesBlockingConversation, }, { name: "blocking_prev_response_id", - fixture: fixtResponsesBlockingPrevResponseID, + fixture: fixtures.OaiResponsesBlockingPrevResponseID, }, - { name: "streaming_simple", - fixture: fixtResponsesStreamingSimple, + fixture: fixtures.OaiResponsesStreamingSimple, streaming: true, }, { name: "streaming_builtin_tool", - fixture: fixtResponsesStreamingBuiltinTool, + fixture: fixtures.OaiResponsesStreamingBuiltinTool, streaming: true, }, { name: "streaming_conversation", - fixture: fixtResponsesStreamingConversation, + fixture: fixtures.OaiResponsesStreamingConversation, streaming: true, }, { name: "streaming_prev_response_id", - fixture: fixtResponsesStreamingPrevResponseID, + fixture: fixtures.OaiResponsesStreamingPrevResponseID, streaming: true, }, { name: "stream_error", - fixture: fixtResponsesStreamingStreamError, + fixture: fixtures.OaiResponsesStreamingStreamError, streaming: true, }, { name: "stream_failure", - fixture: fixtResponsesStreamingStreamFailure, + fixture: fixtures.OaiResponsesStreamingStreamFailure, streaming: true, }, - // Even when response has wrong json format original response status code, body is kept as is + // Original status code and body is kept even with wrong json format { name: "blocking_wrong_format", - fixture: fixtResponsesBlockingWrongResponseFormat, + fixture: fixtures.OaiResponsesBlockingWrongResponseFormat, }, { name: "streaming_wrong_format", - fixture: fixtResponsesStreamingWrongResponseFormat, + fixture: fixtures.OaiResponsesStreamingWrongResponseFormat, streaming: true, }, } diff --git a/trace_integration_test.go b/trace_integration_test.go index d7b339cf..f3771f3a 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -14,6 +14,7 @@ import ( "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/aibridge" "github.com/coder/aibridge/config" + "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" @@ -87,7 +88,7 @@ func TestTraceAnthropic(t *testing.T) { }, } - arc := txtar.Parse(antSingleBuiltinTool) + arc := txtar.Parse(fixtures.AntSingleBuiltinTool) files := filesMap(arc) require.Contains(t, files, fixtureRequest) @@ -127,8 +128,8 @@ func TestTraceAnthropic(t *testing.T) { defer resp.Body.Close() srv.Close() - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + require.Equal(t, 1, len(recorder.RecordedInterceptions())) + intcID := recorder.RecordedInterceptions()[0].ID model := gjson.Get(string(reqBody), "model").Str if tc.bedrock { @@ -212,9 +213,9 @@ func TestTraceAnthropicErr(t *testing.T) { var arc *txtar.Archive if tc.streaming { - arc = txtar.Parse(antMidStreamErr) + arc = txtar.Parse(fixtures.AntMidStreamError) } else { - arc = txtar.Parse(antNonStreamErr) + arc = txtar.Parse(fixtures.AntNonStreamError) } files := filesMap(arc) @@ -257,8 +258,8 @@ func TestTraceAnthropicErr(t *testing.T) { defer resp.Body.Close() srv.Close() - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + require.Equal(t, 1, len(recorder.RecordedInterceptions())) + intcID := recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -348,12 +349,12 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) + recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.AntSingleInjectedTool, tc.streaming, configureFn, reqFunc) defer resp.Body.Close() - require.Len(t, recorderClient.interceptions, 1) - intcID := recorderClient.interceptions[0].ID + require.Len(t, recorderClient.RecordedInterceptions(), 1) + intcID := recorderClient.RecordedInterceptions()[0].ID model := gjson.Get(string(reqBody), "model").Str if tc.bedrock { @@ -392,7 +393,7 @@ func TestTraceOpenAI(t *testing.T) { }{ { name: "trace_openai_streaming", - fixture: oaiSimple, + fixture: fixtures.OaiChatSimple, streaming: true, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, @@ -407,7 +408,7 @@ func TestTraceOpenAI(t *testing.T) { }, { name: "trace_openai_non_streaming", - fixture: oaiSimple, + fixture: fixtures.OaiChatSimple, streaming: false, expect: []expectTrace{ {"Intercept", 1, codes.Unset}, @@ -457,8 +458,8 @@ func TestTraceOpenAI(t *testing.T) { defer resp.Body.Close() srv.Close() - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + require.Equal(t, 1, len(recorder.RecordedInterceptions())) + intcID := recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -519,9 +520,9 @@ func TestTraceOpenAIErr(t *testing.T) { var arc *txtar.Archive if tc.streaming { - arc = txtar.Parse(oaiMidStreamErr) + arc = txtar.Parse(fixtures.OaiChatMidStreamError) } else { - arc = txtar.Parse(oaiNonStreamErr) + arc = txtar.Parse(fixtures.OaiChatNonStreamError) } files := filesMap(arc) @@ -559,8 +560,8 @@ func TestTraceOpenAIErr(t *testing.T) { defer resp.Body.Close() srv.Close() - require.Equal(t, 1, len(recorder.interceptions)) - intcID := recorder.interceptions[0].ID + require.Equal(t, 1, len(recorder.RecordedInterceptions())) + intcID := recorder.RecordedInterceptions()[0].ID totalCount := 0 for _, e := range tc.expect { @@ -609,12 +610,12 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) + recorderClient, _, proxies, resp := setupInjectedToolTest(t, fixtures.OaiChatSingleInjectedTool, streaming, configureFn, reqFunc) defer resp.Body.Close() - require.Len(t, recorderClient.interceptions, 1) - intcID := recorderClient.interceptions[0].ID + require.Len(t, recorderClient.RecordedInterceptions(), 1) + intcID := recorderClient.RecordedInterceptions()[0].ID for _, proxy := range proxies { require.NotEmpty(t, proxy.ListTools()) @@ -641,7 +642,7 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { func TestTracePassthrough(t *testing.T) { t.Parallel() - arc := txtar.Parse(oaiFallthrough) + arc := txtar.Parse(fixtures.OaiChatFallthrough) files := filesMap(arc) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {