diff --git a/go.mod b/go.mod index e8abe95e9..37fce7016 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,8 @@ module charm.land/fantasy go 1.25.0 +toolchain go1.26.3 + require ( charm.land/x/vcr v0.1.1 cloud.google.com/go/auth v0.18.2 @@ -66,11 +68,11 @@ require ( go.opentelemetry.io/otel/metric v1.39.0 // indirect go.opentelemetry.io/otel/trace v1.39.0 // indirect go.yaml.in/yaml/v4 v4.0.0-rc.3 // indirect - golang.org/x/crypto v0.47.0 // indirect - golang.org/x/net v0.49.0 // indirect - golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/crypto v0.50.0 // indirect + golang.org/x/net v0.53.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.43.0 // indirect + golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.14.0 // indirect google.golang.org/api v0.264.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect diff --git a/go.sum b/go.sum index 8920215f4..65c2bc710 100644 --- a/go.sum +++ b/go.sum @@ -152,18 +152,18 @@ go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6 go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.yaml.in/yaml/v4 v4.0.0-rc.3 h1:3h1fjsh1CTAPjW7q/EMe+C8shx5d8ctzZTrLcs/j8Go= go.yaml.in/yaml/v4 v4.0.0-rc.3/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= diff --git a/providers/openai/computer_use.go b/providers/openai/computer_use.go index a5a871f7c..dd2e6c21c 100644 --- a/providers/openai/computer_use.go +++ b/providers/openai/computer_use.go @@ -11,12 +11,12 @@ import ( const computerUseToolID = "openai.computer_use" -// Type identifier for computer use metadata, registered in -// responses_options.go init(). +// TypeComputerUseMetadata is the type identifier for computer use metadata, +// registered in responses_options.go init(). const TypeComputerUseMetadata = Name + ".responses.computer_use_metadata" -// Type identifier for computer call output options, registered in -// responses_options.go init(). +// TypeComputerCallOutputOptions is the type identifier for computer call output +// options, registered in responses_options.go init(). const TypeComputerCallOutputOptions = Name + ".responses.computer_call_output_options" // ComputerUseMetadata stores the raw wire-format JSON of a computer_call diff --git a/providers/openai/openai_test.go b/providers/openai/openai_test.go index c344e6fad..66b4be8e9 100644 --- a/providers/openai/openai_test.go +++ b/providers/openai/openai_test.go @@ -3184,7 +3184,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3212,7 +3212,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3251,7 +3251,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3274,7 +3274,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3298,7 +3298,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3331,7 +3331,7 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) @@ -3364,14 +3364,13 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) require.Len(t, input, 2) require.Empty(t, warnings) }) - } func TestParseContextTooLargeError(t *testing.T) { @@ -4015,7 +4014,7 @@ func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) { t.Run("store false skips item reference", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system instructions", false) + input, warnings, err := toResponsesPrompt(prompt, "system instructions", false, false) require.NoError(t, err) @@ -4029,7 +4028,7 @@ func TestResponsesToPrompt_WebSearchProviderExecutedToolResults(t *testing.T) { t.Run("store true skips item reference", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system instructions", true) + input, warnings, err := toResponsesPrompt(prompt, "system instructions", true, false) require.NoError(t, err) @@ -4083,7 +4082,7 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { t.Run("store true emits item_reference for reasoning", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", true) + input, warnings, err := toResponsesPrompt(prompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4134,7 +4133,7 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(noIDPrompt, "system", true) + input, warnings, err := toResponsesPrompt(noIDPrompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4151,7 +4150,7 @@ func TestResponsesToPrompt_ReasoningWithStore(t *testing.T) { t.Run("store false skips reasoning", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4216,7 +4215,7 @@ func TestResponsesToPrompt_ReasoningWithWebSearchCombined(t *testing.T) { t.Run("store true pairs reasoning and web search", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", true) + input, warnings, err := toResponsesPrompt(prompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4236,7 +4235,7 @@ func TestResponsesToPrompt_ReasoningWithWebSearchCombined(t *testing.T) { t.Run("store false skips both reasoning and provider tool call", func(t *testing.T) { t.Parallel() - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4272,7 +4271,7 @@ func TestResponsesToPrompt_WebSearchRequiresReasoningReference(t *testing.T) { fantasy.TextPart{Text: "Search completed."}, }, }, - }, "system", true) + }, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4329,7 +4328,7 @@ func TestResponsesToPrompt_ReasoningWithFunctionCallCombined(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", true) + input, warnings, err := toResponsesPrompt(prompt, "system", true, false) require.NoError(t, err) require.Empty(t, warnings) @@ -4905,7 +4904,7 @@ func TestComputerUseGenerateRoundTrip_NonImageResult(t *testing.T) { }, } - input, warnings, err := toResponsesPrompt(prompt, "system", false) + input, warnings, err := toResponsesPrompt(prompt, "system", false, false) require.NoError(t, err) // Should warn about non-image result. diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 077c27325..8a416fb45 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -166,7 +166,8 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res params.Store = param.NewOpt(false) } - if openaiOptions != nil && openaiOptions.PreviousResponseID != nil && *openaiOptions.PreviousResponseID != "" { + previousResponseID := openaiOptions != nil && openaiOptions.PreviousResponseID != nil && *openaiOptions.PreviousResponseID != "" + if previousResponseID { if err := validatePreviousResponseIDPrompt(call.Prompt); err != nil { return nil, warnings, err } @@ -177,7 +178,7 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res } storeEnabled := openaiOptions != nil && openaiOptions.Store != nil && *openaiOptions.Store - input, inputWarnings, err := toResponsesPrompt(call.Prompt, modelConfig.systemMessageMode, storeEnabled) + input, inputWarnings, err := toResponsesPrompt(call.Prompt, modelConfig.systemMessageMode, storeEnabled, previousResponseID) warnings = append(warnings, inputWarnings...) if err != nil { return nil, warnings, err @@ -400,7 +401,7 @@ func responsesUsage(resp responses.Response) fantasy.Usage { return usage } -func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bool) (responses.ResponseInputParam, []fantasy.CallWarning, error) { +func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bool, previousResponseID bool) (responses.ResponseInputParam, []fantasy.CallWarning, error) { var input responses.ResponseInputParam var warnings []fantasy.CallWarning @@ -741,7 +742,7 @@ func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string, store bo } } - if err := validateResponsesInput(input); err != nil { + if err := validateResponsesInput(input, previousResponseID); err != nil { return nil, warnings, err } @@ -753,14 +754,14 @@ func isResponsesWebSearchToolCall(toolCallPart fantasy.ToolCallPart) bool { toolCallPart.ToolName == "web_search_preview" } -func validateResponsesInput(input responses.ResponseInputParam) error { - if err := validateResponsesFunctionCallOutputs(input); err != nil { +func validateResponsesInput(input responses.ResponseInputParam, previousResponseID bool) error { + if err := validateResponsesFunctionCallOutputs(input, previousResponseID); err != nil { return err } return validateResponsesItemReferences(input) } -func validateResponsesFunctionCallOutputs(input responses.ResponseInputParam) error { +func validateResponsesFunctionCallOutputs(input responses.ResponseInputParam, previousResponseID bool) error { type callState struct { calls int outputs int @@ -818,6 +819,9 @@ func validateResponsesFunctionCallOutputs(input responses.ResponseInputParam) er for _, callID := range outputIDs { state := states[callID] if state.calls == 0 { + if previousResponseID { + continue + } return fmt.Errorf("openai responses prompt has function_call_output without function_call for call_id %q", callID) } if state.firstOutput < state.firstCall { diff --git a/providers/openai/responses_params_test.go b/providers/openai/responses_params_test.go index 800eff6ba..ad34d4b6d 100644 --- a/providers/openai/responses_params_test.go +++ b/providers/openai/responses_params_test.go @@ -471,7 +471,7 @@ func TestPrepareParams_ValidatesFunctionCallOutputPairing(t *testing.T) { input, warnings, err := toResponsesPrompt(fantasy.Prompt{ testResponsesProviderToolResultMessage("ws_01"), - }, "system", false) + }, "system", false, false) require.NoError(t, err) require.Empty(t, warnings) require.Empty(t, input) @@ -498,7 +498,7 @@ func TestValidateResponsesInput_WebSearchReferenceRequiresReasoning(t *testing.T err := validateResponsesInput(responses.ResponseInputParam{ responses.ResponseInputItemParamOfItemReference("rs_valid"), responses.ResponseInputItemParamOfItemReference("ws_valid"), - }) + }, false) require.NoError(t, err) }) @@ -507,7 +507,7 @@ func TestValidateResponsesInput_WebSearchReferenceRequiresReasoning(t *testing.T err := validateResponsesInput(responses.ResponseInputParam{ responses.ResponseInputItemParamOfItemReference("ws_orphan"), - }) + }, false) require.EqualError(t, err, `openai responses prompt has web_search_call item_reference without preceding reasoning item_reference for item_id "ws_orphan"`) }) @@ -518,7 +518,7 @@ func TestValidateResponsesInput_WebSearchReferenceRequiresReasoning(t *testing.T responses.ResponseInputItemParamOfItemReference("rs_valid"), responses.ResponseInputItemParamOfMessage("text", responses.EasyInputMessageRoleAssistant), responses.ResponseInputItemParamOfItemReference("ws_orphan"), - }) + }, false) require.EqualError(t, err, `openai responses prompt has web_search_call item_reference without preceding reasoning item_reference for item_id "ws_orphan"`) }) } diff --git a/providers/openaicompat/language_model_hooks.go b/providers/openaicompat/language_model_hooks.go index b43cb4eea..643ab6e66 100644 --- a/providers/openaicompat/language_model_hooks.go +++ b/providers/openaicompat/language_model_hooks.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "maps" "strings" "charm.land/fantasy" @@ -47,6 +48,19 @@ func PrepareCallFunc(_ fantasy.LanguageModel, params *openaisdk.ChatCompletionNe if providerOptions.User != nil { params.User = param.NewOpt(*providerOptions.User) } + if providerOptions.ParallelToolCalls != nil { + params.ParallelToolCalls = param.NewOpt(*providerOptions.ParallelToolCalls) + } + if providerOptions.MaxCompletionTokens != nil { + params.MaxCompletionTokens = param.NewOpt(*providerOptions.MaxCompletionTokens) + } + if providerOptions.PromptCacheKey != nil { + params.PromptCacheKey = param.NewOpt(*providerOptions.PromptCacheKey) + } + + extraFields := make(map[string]any) + maps.Copy(extraFields, providerOptions.ExtraBody) + params.SetExtraFields(extraFields) return nil, nil } diff --git a/providers/openaicompat/openaicompat_test.go b/providers/openaicompat/openaicompat_test.go index a603e5e66..e33f60e7f 100644 --- a/providers/openaicompat/openaicompat_test.go +++ b/providers/openaicompat/openaicompat_test.go @@ -5,9 +5,55 @@ import ( "testing" "charm.land/fantasy" + openaisdk "github.com/charmbracelet/openai-go" "github.com/stretchr/testify/require" ) +func TestPrepareCallFunc_ProviderOptions(t *testing.T) { + t.Parallel() + + t.Run("should set chat completion provider options", func(t *testing.T) { + t.Parallel() + + params := openaisdk.ChatCompletionNewParams{} + warnings, err := PrepareCallFunc(nil, ¶ms, fantasy.Call{ + ProviderOptions: NewProviderOptions(&ProviderOptions{ + ParallelToolCalls: fantasy.Opt(false), + MaxCompletionTokens: fantasy.Opt(int64(255)), + PromptCacheKey: fantasy.Opt("test-cache-key-123"), + ExtraBody: map[string]any{ + "custom_field": "custom-value", + }, + }), + }) + + require.NoError(t, err) + require.Empty(t, warnings) + require.True(t, params.ParallelToolCalls.Valid()) + require.Equal(t, false, params.ParallelToolCalls.Value) + require.True(t, params.MaxCompletionTokens.Valid()) + require.Equal(t, int64(255), params.MaxCompletionTokens.Value) + require.True(t, params.PromptCacheKey.Valid()) + require.Equal(t, "test-cache-key-123", params.PromptCacheKey.Value) + require.Equal(t, map[string]any{ + "custom_field": "custom-value", + }, params.ExtraFields()) + }) + + t.Run("should leave unset chat completion provider options invalid", func(t *testing.T) { + t.Parallel() + + params := openaisdk.ChatCompletionNewParams{} + warnings, err := PrepareCallFunc(nil, ¶ms, fantasy.Call{}) + + require.NoError(t, err) + require.Empty(t, warnings) + require.False(t, params.ParallelToolCalls.Valid()) + require.False(t, params.MaxCompletionTokens.Valid()) + require.False(t, params.PromptCacheKey.Valid()) + }) +} + func TestToPromptFunc_ReasoningContent(t *testing.T) { t.Parallel() diff --git a/providers/openaicompat/provider_options.go b/providers/openaicompat/provider_options.go index afb037bf2..602a5303f 100644 --- a/providers/openaicompat/provider_options.go +++ b/providers/openaicompat/provider_options.go @@ -26,8 +26,13 @@ func init() { // ProviderOptions represents additional options for the OpenAI-compatible provider. type ProviderOptions struct { - User *string `json:"user"` - ReasoningEffort *openai.ReasoningEffort `json:"reasoning_effort"` + User *string `json:"user"` + ParallelToolCalls *bool `json:"parallel_tool_calls"` + ReasoningEffort *openai.ReasoningEffort `json:"reasoning_effort"` + MaxCompletionTokens *int64 `json:"max_completion_tokens"` + PromptCacheKey *string `json:"prompt_cache_key"` + // ExtraBody contains additional request body fields. + ExtraBody map[string]any `json:"extra_body,omitempty"` } // ReasoningData represents reasoning data for OpenAI-compatible provider. diff --git a/providertests/openai_computer_use_test.go b/providertests/openai_computer_use_test.go index eda625920..532818be1 100644 --- a/providertests/openai_computer_use_test.go +++ b/providertests/openai_computer_use_test.go @@ -95,7 +95,7 @@ func TestOpenAIComputerUse(t *testing.T) { result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: "Take a screenshot of the desktop", - MaxOutputTokens: new(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), ProviderOptions: providerOpts, }) require.NoError(t, err) @@ -134,7 +134,7 @@ func TestOpenAIComputerUse(t *testing.T) { result, err := agent.Stream(t.Context(), fantasy.AgentStreamCall{ Prompt: "Take a screenshot of the desktop", - MaxOutputTokens: new(int64(4000)), + MaxOutputTokens: fantasy.Opt(int64(4000)), ProviderOptions: providerOpts, }) require.NoError(t, err) @@ -238,7 +238,7 @@ func TestOpenAIComputerUse_AllActions(t *testing.T) { result, err := agent.Generate(t.Context(), fantasy.AgentCall{ Prompt: prompt, - MaxOutputTokens: new(int64(16000)), + MaxOutputTokens: fantasy.Opt(int64(16000)), ProviderOptions: providerOpts, }) require.NoError(t, err)