diff --git a/cmd/cliflags/flags.go b/cmd/cliflags/flags.go index 1f14f626..b1615c2e 100644 --- a/cmd/cliflags/flags.go +++ b/cmd/cliflags/flags.go @@ -8,6 +8,8 @@ const ( AccessTokenFlag = "access-token" AnalyticsOptOut = "analytics-opt-out" BaseURIFlag = "base-uri" + CorsEnabledFlag = "cors-enabled" + CorsOriginFlag = "cors-origin" DataFlag = "data" DevStreamURIFlag = "dev-stream-uri" EmailsFlag = "emails" @@ -22,6 +24,8 @@ const ( AccessTokenFlagDescription = "LaunchDarkly access token with write-level access" AnalyticsOptOutDescription = "Opt out of analytics tracking" BaseURIFlagDescription = "LaunchDarkly base URI" + CorsEnabledFlagDescription = "Enable CORS headers for browser-based developer tools (default: false)" + CorsOriginFlagDescription = "Allowed CORS origin. Use '*' for all origins (default: '*')" DevStreamURIDescription = "Streaming service endpoint that the dev server uses to obtain authoritative flag data. This may be a LaunchDarkly or Relay Proxy endpoint" EnvironmentFlagDescription = "Default environment key" FlagFlagDescription = "Default feature flag key" @@ -36,6 +40,8 @@ func AllFlagsHelp() map[string]string { AccessTokenFlag: AccessTokenFlagDescription, AnalyticsOptOut: AnalyticsOptOutDescription, BaseURIFlag: BaseURIFlagDescription, + CorsEnabledFlag: CorsEnabledFlagDescription, + CorsOriginFlag: CorsOriginFlagDescription, DevStreamURIFlag: DevStreamURIDescription, EnvironmentFlag: EnvironmentFlagDescription, FlagFlag: FlagFlagDescription, diff --git a/cmd/config/testdata/help.golden b/cmd/config/testdata/help.golden index 4783dd60..1ca8abbf 100644 --- a/cmd/config/testdata/help.golden +++ b/cmd/config/testdata/help.golden @@ -4,6 +4,8 @@ Supported settings: - `access-token`: LaunchDarkly access token with write-level access - `analytics-opt-out`: Opt out of analytics tracking - `base-uri`: LaunchDarkly base URI +- `cors-enabled`: Enable CORS headers for browser-based developer tools (default: false) +- `cors-origin`: Allowed CORS origin. Use '*' for all origins (default: '*') - `dev-stream-uri`: Streaming service endpoint that the dev server uses to obtain authoritative flag data. This may be a LaunchDarkly or Relay Proxy endpoint - `environment`: Default environment key - `flag`: Default feature flag key diff --git a/cmd/dev_server/dev_server.go b/cmd/dev_server/dev_server.go index a630f722..549d4dd5 100644 --- a/cmd/dev_server/dev_server.go +++ b/cmd/dev_server/dev_server.go @@ -50,6 +50,20 @@ func NewDevServerCmd(client resources.Client, analyticsTrackerFn analytics.Track _ = viper.BindPFlag(cliflags.PortFlag, cmd.PersistentFlags().Lookup(cliflags.PortFlag)) + cmd.PersistentFlags().Bool( + cliflags.CorsEnabledFlag, + false, + cliflags.CorsEnabledFlagDescription, + ) + _ = viper.BindPFlag(cliflags.CorsEnabledFlag, cmd.PersistentFlags().Lookup(cliflags.CorsEnabledFlag)) + + cmd.PersistentFlags().String( + cliflags.CorsOriginFlag, + "*", + cliflags.CorsOriginFlagDescription, + ) + _ = viper.BindPFlag(cliflags.CorsOriginFlag, cmd.PersistentFlags().Lookup(cliflags.CorsOriginFlag)) + // Add subcommands here cmd.AddGroup(&cobra.Group{ID: "projects", Title: "Project commands:"}) cmd.AddCommand(NewListProjectsCmd(client)) diff --git a/cmd/dev_server/start_server.go b/cmd/dev_server/start_server.go index d14ceb38..a155d929 100644 --- a/cmd/dev_server/start_server.go +++ b/cmd/dev_server/start_server.go @@ -89,6 +89,8 @@ func startServer(client dev_server.Client) func(*cobra.Command, []string) error BaseURI: viper.GetString(cliflags.BaseURIFlag), DevStreamURI: viper.GetString(cliflags.DevStreamURIFlag), Port: viper.GetString(cliflags.PortFlag), + CorsEnabled: viper.GetBool(cliflags.CorsEnabledFlag), + CorsOrigin: viper.GetString(cliflags.CorsOriginFlag), InitialProjectSettings: initialSetting, } diff --git a/internal/dev_server/api/cors.go b/internal/dev_server/api/cors.go new file mode 100644 index 00000000..5173ec7c --- /dev/null +++ b/internal/dev_server/api/cors.go @@ -0,0 +1,31 @@ +package api + +import ( + "net/http" +) + +// CorsHeadersWithConfig provides configurable CORS support for the dev-server admin API endpoints. +// When enabled=false, no CORS headers are added. +// When enabled=true, CORS headers are added with the specified origin. +func CorsHeadersWithConfig(enabled bool, origin string) func(http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if enabled { + writer.Header().Set("Access-Control-Allow-Origin", origin) + writer.Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS") + writer.Header().Set("Access-Control-Allow-Credentials", "true") + writer.Header().Set("Access-Control-Allow-Headers", "Accept,Content-Type,Content-Length,Accept-Encoding,Authorization,X-Requested-With") + writer.Header().Set("Access-Control-Expose-Headers", "Date,Content-Length") + writer.Header().Set("Access-Control-Max-Age", "300") + + // Handle preflight OPTIONS requests + if request.Method == http.MethodOptions { + writer.WriteHeader(http.StatusOK) + return + } + } + + handler.ServeHTTP(writer, request) + }) + } +} diff --git a/internal/dev_server/api/cors_test.go b/internal/dev_server/api/cors_test.go new file mode 100644 index 00000000..202008ac --- /dev/null +++ b/internal/dev_server/api/cors_test.go @@ -0,0 +1,63 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCorsHeadersWithConfig_Enabled(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("test response")) + require.NoError(t, err) + }) + + corsHandler := CorsHeadersWithConfig(true, "*")(handler) + + // Test GET request + req := httptest.NewRequest("GET", "/dev/projects", nil) + w := httptest.NewRecorder() + corsHandler.ServeHTTP(w, req) + + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "GET,POST,PUT,PATCH,DELETE,OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestCorsHeadersWithConfig_OptionsRequest(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("Handler should not be called for OPTIONS request") + }) + + corsHandler := CorsHeadersWithConfig(true, "https://example.com")(handler) + + // Test OPTIONS preflight request + req := httptest.NewRequest("OPTIONS", "/dev/projects", nil) + w := httptest.NewRecorder() + corsHandler.ServeHTTP(w, req) + + assert.Equal(t, "https://example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestCorsHeadersWithConfig_Disabled(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("test response")) + require.NoError(t, err) + }) + + corsHandler := CorsHeadersWithConfig(false, "*")(handler) + + // Test GET request with CORS disabled + req := httptest.NewRequest("GET", "/dev/projects", nil) + w := httptest.NewRecorder() + corsHandler.ServeHTTP(w, req) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"), "Expected no CORS headers when disabled") + assert.Equal(t, http.StatusOK, w.Code) +} diff --git a/internal/dev_server/dev_server.go b/internal/dev_server/dev_server.go index 0e8712dd..1d9233e9 100644 --- a/internal/dev_server/dev_server.go +++ b/internal/dev_server/dev_server.go @@ -29,6 +29,8 @@ type ServerParams struct { BaseURI string DevStreamURI string Port string + CorsEnabled bool + CorsOrigin string InitialProjectSettings model.InitialProjectSettings } @@ -65,6 +67,7 @@ func (c LDClient) RunServer(ctx context.Context, serverParams ServerParams) { r.PathPrefix("/ui/").Handler(http.StripPrefix("/ui/", ui.AssetHandler)) sdk.BindRoutes(r) handler := api.HandlerFromMux(apiServer, r) + handler = api.CorsHeadersWithConfig(serverParams.CorsEnabled, serverParams.CorsOrigin)(handler) handler = handlers.CombinedLoggingHandler(os.Stdout, handler) handler = handlers.RecoveryHandler(handlers.PrintRecoveryStack(true))(handler)