|
1 | 1 | package openai_compat |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
4 | 5 | "encoding/json" |
| 6 | + "fmt" |
| 7 | + "io" |
5 | 8 | "net/http" |
6 | 9 | "net/http/httptest" |
7 | 10 | "net/url" |
@@ -212,6 +215,132 @@ func TestProviderChat_HTTPError(t *testing.T) { |
212 | 215 | } |
213 | 216 | } |
214 | 217 |
|
| 218 | +func TestProviderChat_JSONHTTPErrorDoesNotReportHTML(t *testing.T) { |
| 219 | + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 220 | + w.Header().Set("Content-Type", "application/json") |
| 221 | + w.WriteHeader(http.StatusBadRequest) |
| 222 | + _, _ = w.Write([]byte(`{"error":"bad request"}`)) |
| 223 | + })) |
| 224 | + defer server.Close() |
| 225 | + |
| 226 | + p := NewProvider("key", server.URL, "") |
| 227 | + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) |
| 228 | + if err == nil { |
| 229 | + t.Fatal("expected error, got nil") |
| 230 | + } |
| 231 | + if !strings.Contains(err.Error(), "Status: 400") { |
| 232 | + t.Fatalf("expected status code in error, got %v", err) |
| 233 | + } |
| 234 | + if strings.Contains(err.Error(), "returned HTML instead of JSON") { |
| 235 | + t.Fatalf("expected non-HTML http error, got %v", err) |
| 236 | + } |
| 237 | +} |
| 238 | + |
| 239 | +func TestProviderChat_HTMLResponsesReturnHelpfulError(t *testing.T) { |
| 240 | + tests := []struct { |
| 241 | + name string |
| 242 | + contentType string |
| 243 | + statusCode int |
| 244 | + body string |
| 245 | + }{ |
| 246 | + { |
| 247 | + name: "html success response", |
| 248 | + contentType: "text/html; charset=utf-8", |
| 249 | + statusCode: http.StatusOK, |
| 250 | + body: "<!DOCTYPE html><html><body>gateway login</body></html>", |
| 251 | + }, |
| 252 | + { |
| 253 | + name: "html error response", |
| 254 | + contentType: "text/html; charset=utf-8", |
| 255 | + statusCode: http.StatusBadGateway, |
| 256 | + body: "<!DOCTYPE html><html><body>bad gateway</body></html>", |
| 257 | + }, |
| 258 | + { |
| 259 | + name: "mislabeled html success response", |
| 260 | + contentType: "application/json", |
| 261 | + statusCode: http.StatusOK, |
| 262 | + body: " \r\n\t<!DOCTYPE html><html><body>gateway login</body></html>", |
| 263 | + }, |
| 264 | + } |
| 265 | + |
| 266 | + for _, tt := range tests { |
| 267 | + t.Run(tt.name, func(t *testing.T) { |
| 268 | + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 269 | + w.Header().Set("Content-Type", tt.contentType) |
| 270 | + w.WriteHeader(tt.statusCode) |
| 271 | + _, _ = w.Write([]byte(tt.body)) |
| 272 | + })) |
| 273 | + defer server.Close() |
| 274 | + |
| 275 | + p := NewProvider("key", server.URL, "") |
| 276 | + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) |
| 277 | + if err == nil { |
| 278 | + t.Fatal("expected error, got nil") |
| 279 | + } |
| 280 | + if !strings.Contains(err.Error(), fmt.Sprintf("Status: %d", tt.statusCode)) { |
| 281 | + t.Fatalf("expected status code in error, got %v", err) |
| 282 | + } |
| 283 | + if !strings.Contains(err.Error(), "returned HTML instead of JSON") { |
| 284 | + t.Fatalf("expected helpful HTML error, got %v", err) |
| 285 | + } |
| 286 | + if !strings.Contains(err.Error(), "check api_base or proxy configuration") { |
| 287 | + t.Fatalf("expected configuration hint, got %v", err) |
| 288 | + } |
| 289 | + }) |
| 290 | + } |
| 291 | +} |
| 292 | + |
| 293 | +func TestProviderChat_SuccessResponseUsesStreamingDecoder(t *testing.T) { |
| 294 | + content := strings.Repeat("a", 1024) |
| 295 | + body := `{"choices":[{"message":{"content":"` + content + `"},"finish_reason":"stop"}]}` |
| 296 | + |
| 297 | + p := NewProvider("key", "https://example.com/v1", "") |
| 298 | + p.httpClient = &http.Client{ |
| 299 | + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { |
| 300 | + return &http.Response{ |
| 301 | + StatusCode: http.StatusOK, |
| 302 | + Header: http.Header{"Content-Type": []string{"application/json"}}, |
| 303 | + Body: &errAfterDataReadCloser{ |
| 304 | + data: []byte(body), |
| 305 | + chunkSize: 64, |
| 306 | + }, |
| 307 | + }, nil |
| 308 | + }), |
| 309 | + } |
| 310 | + |
| 311 | + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) |
| 312 | + if err != nil { |
| 313 | + t.Fatalf("Chat() error = %v", err) |
| 314 | + } |
| 315 | + if out.Content != content { |
| 316 | + t.Fatalf("Content = %q, want %q", out.Content, content) |
| 317 | + } |
| 318 | +} |
| 319 | + |
| 320 | +func TestProviderChat_LargeHTMLResponsePreviewIsTruncated(t *testing.T) { |
| 321 | + body := append([]byte("<!DOCTYPE html><html><body>"), bytes.Repeat([]byte("A"), 2048)...) |
| 322 | + body = append(body, []byte("</body></html>")...) |
| 323 | + |
| 324 | + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 325 | + w.Header().Set("Content-Type", "text/html; charset=utf-8") |
| 326 | + w.WriteHeader(http.StatusBadGateway) |
| 327 | + _, _ = w.Write(body) |
| 328 | + })) |
| 329 | + defer server.Close() |
| 330 | + |
| 331 | + p := NewProvider("key", server.URL, "") |
| 332 | + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) |
| 333 | + if err == nil { |
| 334 | + t.Fatal("expected error, got nil") |
| 335 | + } |
| 336 | + if !strings.Contains(err.Error(), "Body: <!DOCTYPE html><html><body>") { |
| 337 | + t.Fatalf("expected html preview in error, got %v", err) |
| 338 | + } |
| 339 | + if !strings.Contains(err.Error(), "...") { |
| 340 | + t.Fatalf("expected truncated preview, got %v", err) |
| 341 | + } |
| 342 | +} |
| 343 | + |
215 | 344 | func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) { |
216 | 345 | var requestBody map[string]any |
217 | 346 |
|
@@ -399,6 +528,40 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) { |
399 | 528 | } |
400 | 529 | } |
401 | 530 |
|
| 531 | +type roundTripperFunc func(*http.Request) (*http.Response, error) |
| 532 | + |
| 533 | +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { |
| 534 | + return f(r) |
| 535 | +} |
| 536 | + |
| 537 | +type errAfterDataReadCloser struct { |
| 538 | + data []byte |
| 539 | + chunkSize int |
| 540 | + offset int |
| 541 | +} |
| 542 | + |
| 543 | +func (r *errAfterDataReadCloser) Read(p []byte) (int, error) { |
| 544 | + if r.offset >= len(r.data) { |
| 545 | + return 0, io.ErrUnexpectedEOF |
| 546 | + } |
| 547 | + |
| 548 | + n := r.chunkSize |
| 549 | + if n <= 0 || n > len(p) { |
| 550 | + n = len(p) |
| 551 | + } |
| 552 | + remaining := len(r.data) - r.offset |
| 553 | + if n > remaining { |
| 554 | + n = remaining |
| 555 | + } |
| 556 | + copy(p, r.data[r.offset:r.offset+n]) |
| 557 | + r.offset += n |
| 558 | + return n, nil |
| 559 | +} |
| 560 | + |
| 561 | +func (r *errAfterDataReadCloser) Close() error { |
| 562 | + return nil |
| 563 | +} |
| 564 | + |
402 | 565 | func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) { |
403 | 566 | p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens")) |
404 | 567 | if p.maxTokensField != "max_completion_tokens" { |
|
0 commit comments