Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions fixtures/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package fixtures

import (
_ "embed"
"testing"

"golang.org/x/tools/txtar"
)

var (
Expand Down Expand Up @@ -65,6 +68,9 @@ var (
//go:embed openai/responses/streaming/simple.txtar
OaiResponsesStreamingSimple []byte

//go:embed openai/responses/streaming/codex_example.txtar
OaiResponsesStreamingCodex []byte

//go:embed openai/responses/streaming/builtin_tool.txtar
OaiResponsesStreamingBuiltinTool []byte

Expand All @@ -83,3 +89,16 @@ var (
//go:embed openai/responses/streaming/wrong_response_format.txtar
OaiResponsesStreamingWrongResponseFormat []byte
)

// Request returns the contents of the file named "request" inside the given
// txtar-encoded fixture. The test is failed immediately when the archive
// contains no such file.
func Request(t *testing.T, fixture []byte) []byte {
	t.Helper()

	files := txtar.Parse(fixture).Files
	for idx := range files {
		if files[idx].Name != "request" {
			continue
		}
		return files[idx].Data
	}

	t.Fatal("request not found in fixture")
	return []byte{}
}
357 changes: 357 additions & 0 deletions fixtures/openai/responses/streaming/codex_example.txtar

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion fixtures/openai/responses/streaming/stream_error.txtar
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-- request --
{
"input": "hello",
"input": "hello_stream_error",
"model": "gpt-6.7",
"stream": true
}
Expand Down
2 changes: 1 addition & 1 deletion fixtures/openai/responses/streaming/stream_failure.txtar
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-- request --
{
"input": "hello",
"input": "hello_stream_failure",
"model": "gpt-6.7",
"stream": true
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-- request --
{
"input": "hello",
"input": "hello_wrong_format",
Comment thread
dannykopping marked this conversation as resolved.
"model": "gpt-6.7",
"stream": true
}
Expand All @@ -13,4 +13,8 @@ event: response.in_progress
data: {"type":"response.in_progress","response":{"id":"resp_123","object":"response","status":"in_progress","error":null,"output":[]},"sequence_number":2}

event: response.output_text.delta
data: { "wrong format": should fail
da
ta: { "wrong format": should be forwarded as received

event: response.completed
data: {"type":"response.completed","response":{"id":"resp_123","object":"response","created_at":1767874658,"status":"completed","background":false,"completed_at":1767874660,"error":null,"incomplete_details":null,"instructions":null,"max_output_tokens":null,"max_tool_calls":null,"model":"gpt-4o-mini-2024-07-18","output":[{"id":"msg_0f9c4b2f224d858000695fa063d4708197af73c2f37cb0b9d3","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":"Why did the scarecrow win an award?\n\nBecause he was outstanding in his field!"}],"role":"assistant"}],"parallel_tool_calls":true,"previous_response_id":null,"prompt_cache_key":null,"prompt_cache_retention":null,"reasoning":{"effort":null,"summary":null},"safety_identifier":null,"service_tier":"default","store":true,"temperature":1.0,"text":{"format":{"type":"text"},"verbosity":"medium"},"tool_choice":"auto","tools":[],"top_logprobs":0,"top_p":1.0,"truncation":"disabled","usage":{"input_tokens":11,"input_tokens_details":{"cached_tokens":0},"output_tokens":18,"output_tokens_details":{"reasoning_tokens":0},"total_tokens":29},"user":null,"metadata":{}},"sequence_number":24}
76 changes: 76 additions & 0 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand All @@ -23,6 +24,7 @@ import (
"github.com/google/uuid"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.opentelemetry.io/otel/attribute"
)
Expand Down Expand Up @@ -144,6 +146,80 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
return opts
}

// lastUserPrompt returns last input message with "user" role
func (i *responsesInterceptionBase) lastUserPrompt() (string, error) {
if i == nil {
return "", errors.New("cannot get last user prompt: nil struct")
}
if i.req == nil {
return "", errors.New("cannot get last user prompt: nil req struct")
}

// 'input' field can be a string or array of objects:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input

// Check string variant
if i.req.Input.OfString.Valid() {
return i.req.Input.OfString.Value, nil
}

// Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field.
// If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList'
// It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message
Comment thread
dannykopping marked this conversation as resolved.
// example: fixtures/openai/responses/blocking/builtin_tool.txtar
inputItems := gjson.GetBytes(i.reqPayload, "input").Array()
for i := len(inputItems) - 1; i >= 0; i-- {
item := inputItems[i]
if item.Get("role").Str == "user" {
var sb strings.Builder

// content can be a string or array of objects:
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
content := item.Get("content")
if content.Str != "" {
return content.Str, nil
}
for _, c := range content.Array() {
if c.Get("type").Str == "input_text" {
sb.WriteString(c.Get("text").Str)
}
}
if sb.Len() > 0 {
return sb.String(), nil
}
}
}

return "", errors.New("failed to find last user prompt")
}

// recordUserPrompt fetches the last user prompt of the request and passes it
// to the recorder together with the interception ID and the given response ID.
//
// Nothing is recorded when the prompt lookup fails, the prompt is empty, or
// responseID is empty. Each of those cases, as well as a recorder failure, is
// logged as a warning; no error is returned to the caller.
func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string) {
	prompt, err := i.lastUserPrompt()
	switch {
	case err != nil:
		i.logger.Warn(ctx, "failed to get last user prompt", slog.Error(err))
		return
	case prompt == "":
		i.logger.Warn(ctx, "got empty last prompt, skipping prompt recording")
		return
	case responseID == "":
		i.logger.Warn(ctx, "got empty response ID, skipping prompt recording")
		return
	}

	recordErr := i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{
		InterceptionID: i.ID().String(),
		MsgID:          responseID,
		Prompt:         prompt,
	})
	if recordErr != nil {
		i.logger.Warn(ctx, "failed to record prompt usage", slog.Error(recordErr))
	}
}

// responseCopier helper struct to send original response to the client
type responseCopier struct {
buff deltaBuffer
Expand Down
196 changes: 196 additions & 0 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package responses

import (
"testing"

"cdr.dev/slog/v3"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/internal/testutil"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

// TestLastUserPrompt verifies that lastUserPrompt returns the expected prompt
// text for request payloads loaded from fixtures.
func TestLastUserPrompt(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name       string
		reqPayload []byte
		expected   string
	}

	cases := []testCase{
		{
			name:       "simple_string_input",
			reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
			expected:   "tell me a joke",
		},
		{
			name:       "array_single_input_string",
			reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingBuiltinTool),
			expected:   "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
		},
		{
			name:       "array_multiple_items_content_objects",
			reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex),
			expected:   "hello",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Decode the payload the same way the interceptor would receive it.
			var params ResponsesNewParamsWrapper
			require.NoError(t, params.UnmarshalJSON(tc.reqPayload))

			base := &responsesInterceptionBase{
				req:        &params,
				reqPayload: tc.reqPayload,
			}

			got, err := base.lastUserPrompt()
			require.NoError(t, err)
			require.Equal(t, tc.expected, got)
		})
	}
}

// TestLastUserPromptErr verifies that lastUserPrompt returns an empty prompt
// and a descriptive error for a nil receiver, a nil request, and payloads that
// contain no user message with text.
func TestLastUserPromptErr(t *testing.T) {
	t.Parallel()

	t.Run("nil_struct", func(t *testing.T) {
		t.Parallel()

		var base *responsesInterceptionBase
		prompt, err := base.lastUserPrompt()
		require.Error(t, err)
		require.Empty(t, prompt)
		// The error message is the haystack and the expected text is the needle.
		require.Contains(t, err.Error(), "cannot get last user prompt: nil struct")
	})

	// Named distinctly from the subtest above so the two do not collide in
	// test output and can be selected individually with -run.
	t.Run("nil_req", func(t *testing.T) {
		t.Parallel()

		base := responsesInterceptionBase{}
		prompt, err := base.lastUserPrompt()
		require.Error(t, err)
		require.Empty(t, prompt)
		require.Contains(t, err.Error(), "cannot get last user prompt: nil req struct")
	})

	tests := []struct {
		name       string
		reqPayload []byte
		wantErrMsg string
	}{
		{
			name:       "empty_input",
			reqPayload: []byte(`{"model": "gpt-4o", "input": []}`),
			wantErrMsg: "failed to find last user prompt",
		},
		{
			name:       "no_user_role",
			reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "assistant", "content": "hello"}]}`),
			wantErrMsg: "failed to find last user prompt",
		},
		{
			name:       "user_with_empty_content",
			reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": ""}]}`),
			wantErrMsg: "failed to find last user prompt",
		},
		{
			name:       "user_with_empty_content_array",
			reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": []}]}`),
			wantErrMsg: "failed to find last user prompt",
		},
		{
			name:       "user_with_non_input_text_content",
			reqPayload: []byte(`{"model": "gpt-4o", "input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`),
			wantErrMsg: "failed to find last user prompt",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			req := &ResponsesNewParamsWrapper{}
			err := req.UnmarshalJSON(tc.reqPayload)
			require.NoError(t, err)

			base := &responsesInterceptionBase{
				req:        req,
				reqPayload: tc.reqPayload,
			}

			prompt, err := base.lastUserPrompt()
			require.Error(t, err)
			require.Empty(t, prompt)
			require.Contains(t, err.Error(), tc.wantErrMsg)
		})
	}
}

// TestRecordPrompt verifies that recordUserPrompt records exactly one prompt
// usage when both the prompt and the response ID are available, and records
// nothing when the response ID is empty or the prompt lookup fails.
func TestRecordPrompt(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name         string
		reqPayload   []byte
		responseID   string
		wantRecorded bool
		wantPrompt   string
	}

	cases := []testCase{
		{
			name:         "records_prompt_successfully",
			reqPayload:   fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
			responseID:   "resp_123",
			wantRecorded: true,
			wantPrompt:   "tell me a joke",
		},
		{
			name:         "skips_recording_on_empty_response_id",
			reqPayload:   fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
			responseID:   "",
			wantRecorded: false,
		},
		{
			name:         "skips_recording_on_lastUserPrompt_error",
			reqPayload:   []byte(`{"model": "gpt-4o", "input": []}`),
			responseID:   "resp_123",
			wantRecorded: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			var params ResponsesNewParamsWrapper
			require.NoError(t, params.UnmarshalJSON(tc.reqPayload))

			mockRec := &testutil.MockRecorder{}
			interceptionID := uuid.New()
			base := &responsesInterceptionBase{
				id:         interceptionID,
				req:        &params,
				reqPayload: tc.reqPayload,
				recorder:   mockRec,
				logger:     slog.Make(),
			}

			base.recordUserPrompt(t.Context(), tc.responseID)

			recorded := mockRec.RecordedPromptUsages()
			if !tc.wantRecorded {
				require.Empty(t, recorded)
				return
			}

			require.Len(t, recorded, 1)
			usage := recorded[0]
			require.Equal(t, interceptionID.String(), usage.InterceptionID)
			require.Equal(t, tc.responseID, usage.MsgID)
			require.Equal(t, tc.wantPrompt, usage.Prompt)
		})
	}
}
7 changes: 6 additions & 1 deletion intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
var respCopy responseCopier

opts := i.requestOptions(&respCopy)
_, upstreamErr := srv.New(ctx, i.req.ResponseNewParams, opts...)
response, upstreamErr := srv.New(ctx, i.req.ResponseNewParams, opts...)

// response could be nil, e.g. fixtures/openai/responses/blocking/wrong_response_format.txtar
if response != nil {
Comment thread
dannykopping marked this conversation as resolved.
i.recordUserPrompt(ctx, response.ID)
}

if upstreamErr != nil && !respCopy.responseReceived.Load() {
// no response received from upstream, return custom error
Expand Down
8 changes: 4 additions & 4 deletions intercept/responses/paramswrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ type ResponsesNewParamsWrapper struct {
Stream bool `json:"stream,omitempty"`
}

func (c *ResponsesNewParamsWrapper) UnmarshalJSON(raw []byte) error {
err := c.ResponseNewParams.UnmarshalJSON(raw)
func (r *ResponsesNewParamsWrapper) UnmarshalJSON(raw []byte) error {
err := r.ResponseNewParams.UnmarshalJSON(raw)
if err != nil {
return fmt.Errorf("failed to unmarshal response params: %w", err)
}

c.Stream = false
r.Stream = false
if stream := gjson.Get(string(raw), "stream"); stream.Bool() {
c.Stream = stream.Bool()
r.Stream = stream.Bool()
}
return nil
}
Loading