diff --git a/lib/httpapi/events_test.go b/lib/httpapi/events_test.go index 4d38c7d9..4e50395b 100644 --- a/lib/httpapi/events_test.go +++ b/lib/httpapi/events_test.go @@ -8,7 +8,6 @@ import ( st "github.com/coder/agentapi/lib/screentracker" "github.com/coder/quartz" "github.com/stretchr/testify/assert" - st "github.com/coder/agentapi/lib/screentracker" ) // Traces to: FR-HTTP-008 diff --git a/lib/httpapi/models.go b/lib/httpapi/models.go index 8357c83d..7bd8e8af 100644 --- a/lib/httpapi/models.go +++ b/lib/httpapi/models.go @@ -13,7 +13,7 @@ type MessageType string const ( MessageTypeUser MessageType = "user" - MessageTypeRaw MessageType = "raw" + MessageTypeRaw MessageType = "raw" MessageTypeCommand MessageType = "command" ) @@ -81,6 +81,13 @@ type HealthResponse struct { } } +// VersionResponse represents the server version response. +type VersionResponse struct { + Body struct { + Version string `json:"version" doc:"AgentAPI version"` + } +} + // ReadyResponse represents the readiness check response type ReadyResponse struct { Body struct { @@ -101,7 +108,7 @@ type StatusResponse struct { type InfoResponse struct { Body struct { Version string `json:"version" doc:"AgentAPI version"` - AgentType mf.AgentType `json:"agent_type" doc:"Type of the agent being used by the server."` + AgentType mf.AgentType `json:"agent_type" doc:"Type of the agent being used by the server."` Features map[string]bool `json:"features" doc:"Supported features"` } } diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index b5272768..f0d2dce9 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -340,185 +340,6 @@ func (s *Server) registerRoutes() { s.registerStaticFileRoutes() } -// getStatus handles GET /status -func (s *Server) getStatus(ctx context.Context, input *struct{}) (*StatusResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - status := s.conversation.Status() - agentStatus := convertStatus(status) - - resp := &StatusResponse{} - resp.Body.Status = agentStatus - resp.Body.AgentType = s.agentType - resp.Body.Transport = s.transport - - return resp, nil -} - -// getMessages handles GET /messages -func (s *Server) getMessages(ctx context.Context, input *struct{}) (*MessagesResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - - resp := &MessagesResponse{} - resp.Body.Messages = make([]Message, len(s.conversation.Messages())) - for i, msg := range s.conversation.Messages() { - resp.Body.Messages[i] = Message{ - Id: msg.Id, - Role: msg.Role, - Content: msg.Message, - Time: msg.Time, - } - } - - return resp, nil -} - -// createMessage handles POST /message -func (s *Server) createMessage(ctx context.Context, input *MessageRequest) (*MessageResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - switch input.Body.Type { - case MessageTypeUser: - if err := s.conversation.Send(FormatMessage(s.agentType, input.Body.Content)...); err != nil { - return nil, xerrors.Errorf("failed to send message: %w", err) - } - case MessageTypeRaw: - if _, err := s.agentio.Write([]byte(input.Body.Content)); err != nil { - return nil, xerrors.Errorf("failed to send message: %w", err) - } - } - - resp := &MessageResponse{} - resp.Body.Ok = true - - return resp, nil -} - -// uploadFiles handles POST /upload -func (s *Server) uploadFiles(ctx context.Context, input *struct { - RawBody huma.MultipartFormFiles[UploadRequest] -}, -) (*UploadResponse, error) { - formData := input.RawBody.Data() - - file := formData.File.File - - // Limit file size to 10MB - const maxFileSize = 10 << 20 // 10MB - buf, err := io.ReadAll(io.LimitReader(file, maxFileSize+1)) - if err != nil { - return nil, xerrors.Errorf("failed to upload file: %w", err) - } - if len(buf) > maxFileSize { - return nil, huma.Error400BadRequest("file size exceeds 10MB limit") - } - - // Calculate checksum of the uploaded file to create unique subdirectory - hash := sha256.Sum256(buf) - checksum := hex.EncodeToString(hash[:8]) // Use first 8 bytes (16 hex chars) - - // Create checksum-based subdirectory in tempDir - uploadDir := filepath.Join(s.tempDir, checksum) - err = os.MkdirAll(uploadDir, 0o755) - if err != nil { - return nil, xerrors.Errorf("failed to create upload directory: %w", err) - } - - // Save individual file with original filename (extract just the base filename for security) - filename := filepath.Base(formData.File.Filename) - - outPath := filepath.Join(uploadDir, filename) - err = os.WriteFile(outPath, buf, 0o644) - if err != nil { - return nil, xerrors.Errorf("failed to write file: %w", err) - } - - resp := &UploadResponse{} - resp.Body.Ok = true - resp.Body.FilePath = outPath - return resp, nil -} - -// subscribeEvents is an SSE endpoint that sends events to the client -func (s *Server) subscribeEvents(ctx context.Context, input *struct{}, send sse.Sender) { - subscriberId, ch, stateEvents := s.emitter.Subscribe() - defer s.emitter.Unsubscribe(subscriberId) - - s.logger.Info("New subscriber", "subscriberId", subscriberId) - for _, event := range stateEvents { - if event.Type == EventTypeScreenUpdate { - continue - } - if err := send.Data(event.Payload); err != nil { - s.logger.Error("Failed to send event", "subscriberId", subscriberId, "error", err) - return - } - } - - for { - select { - case event, ok := <-ch: - if !ok { - s.logger.Info("Channel closed", "subscriberId", subscriberId) - return - } - if event.Type == EventTypeScreenUpdate { - continue - } - if err := send.Data(event.Payload); err != nil { - s.logger.Error("Failed to send event", "subscriberId", subscriberId, "error", err) - return - } - case <-s.shutdownCtx.Done(): - s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) - return - case <-ctx.Done(): - s.logger.Info("Context done", "subscriberId", subscriberId) - return - } - } -} - -func (s *Server) subscribeScreen(ctx context.Context, input *struct{}, send sse.Sender) { - subscriberId, ch, stateEvents := s.emitter.Subscribe() - defer s.emitter.Unsubscribe(subscriberId) - s.logger.Info("New screen subscriber", "subscriberId", subscriberId) - for _, event := range stateEvents { - if event.Type != EventTypeScreenUpdate { - continue - } - if err := send.Data(event.Payload); err != nil { - s.logger.Error("Failed to send screen event", "subscriberId", subscriberId, "error", err) - return - } - } - for { - select { - case event, ok := <-ch: - if !ok { - s.logger.Info("Screen channel closed", "subscriberId", subscriberId) - return - } - if event.Type != EventTypeScreenUpdate { - continue - } - if err := send.Data(event.Payload); err != nil { - s.logger.Error("Failed to send screen event", "subscriberId", subscriberId, "error", err) - return - } - case <-s.shutdownCtx.Done(): - s.logger.Info("Server stop initiated, unsubscribing.", "subscriberId", subscriberId) - return - case <-ctx.Done(): - s.logger.Info("Screen context done", "subscriberId", subscriberId) - return - } - } -} - // Start starts the HTTP server func (s *Server) Start() error { addr := fmt.Sprintf(":%d", s.port) diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index d4bedded..0efbc7df 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -1,12 +1,25 @@ package httpapi import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "mime/multipart" "net/http" "net/http/httptest" + "os" + "path/filepath" + "strings" "testing" "time" + "github.com/coder/agentapi/lib/logctx" + "github.com/coder/agentapi/lib/msgfmt" "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Traces to: FR-HTTP-003, FR-HTTP-005 @@ -16,7 +29,7 @@ func TestOpenAPISchema(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + srv, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -24,16 +37,20 @@ func TestOpenAPISchema(t *testing.T) { AllowedHosts: []string{"*"}, AllowedOrigins: []string{"*"}, }) + require.NoError(t, err) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest("GET", "/openapi.json", nil) req.Host = "evil.com" w := httptest.NewRecorder() - router.ServeHTTP(w, req) + srv.Handler().ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("expected 200 for wildcard, got %d", w.Code) } + if !json.Valid(w.Body.Bytes()) { + t.Fatalf("expected valid OpenAPI JSON, got %q", w.Body.String()) + } } // Traces to: FR-SEC-001 @@ -154,12 +171,7 @@ func TestParseAllowedHosts_Valid(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - var diskSchema any - if err := json.Unmarshal(diskSchemaBytes, &diskSchema); err != nil { - t.Fatalf("failed to unmarshal disk schema: %s", err) - } - - require.Equal(t, currentSchema, diskSchema) + require.Equal(t, []string{"localhost", "example.com"}, hosts) } func TestServer_redirectToChat(t *testing.T) { @@ -176,7 +188,7 @@ func TestServer_redirectToChat(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() tCtx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - s, err := httpapi.NewServer(tCtx, httpapi.ServerConfig{ + s, err := NewServer(tCtx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -340,7 +352,7 @@ func TestServer_AllowedHosts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + s, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -423,7 +435,7 @@ func TestServer_CORSPreflightWithHosts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + s, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -582,7 +594,7 @@ func TestServer_CORSOrigins(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + s, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -662,7 +674,7 @@ func TestServer_CORSPreflightOrigins(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + s, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -713,7 +725,7 @@ func TestServer_CORSPreflightOrigins(t *testing.T) { func TestServer_SSEMiddleware_Events(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + srv, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -760,7 +772,7 @@ func assertSSEHeaders(t testing.TB, resp *http.Response) { func TestServer_UploadFiles(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + srv, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -915,7 +927,7 @@ func TestServer_UploadFiles(t *testing.T) { func TestServer_UploadFiles_Errors(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + srv, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0, @@ -1061,7 +1073,7 @@ func TestServer_Stop_Idempotency(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + srv, err := NewServer(ctx, ServerConfig{ AgentType: msgfmt.AgentTypeClaude, AgentIO: nil, Port: 0,