diff --git a/.cspell.json b/.cspell.json index 76865b83..0bb6cde2 100644 --- a/.cspell.json +++ b/.cspell.json @@ -4,6 +4,7 @@ "words": [ "Diátaxis", "GOFLAGS", + "Millis", "apiregistration", "apiserverapp", "apiserver", diff --git a/api/aggregation/v1alpha1/types.go b/api/aggregation/v1alpha1/types.go index 801687d8..5f4204e8 100644 --- a/api/aggregation/v1alpha1/types.go +++ b/api/aggregation/v1alpha1/types.go @@ -4,14 +4,34 @@ import metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" // CoderWorkspaceSpec defines the desired state of a CoderWorkspace. type CoderWorkspaceSpec struct { - // Running indicates whether the workspace should be running. + // Organization is the Coder organization name. + Organization string `json:"organization,omitempty"` + + // TemplateName resolves via TemplateByName(organization, templateName). + TemplateName string `json:"templateName,omitempty"` + + // TemplateVersionID optionally pins to a specific template version. + TemplateVersionID string `json:"templateVersionID,omitempty"` + + // Running drives start/stop via CreateWorkspaceBuild. Running bool `json:"running"` + + TTLMillis *int64 `json:"ttlMillis,omitempty"` + AutostartSchedule *string `json:"autostartSchedule,omitempty"` } // CoderWorkspaceStatus defines the observed state of a CoderWorkspace. type CoderWorkspaceStatus struct { - // AutoShutdown is the next planned shutdown time for the workspace. + ID string `json:"id,omitempty"` + OwnerName string `json:"ownerName,omitempty"` + OrganizationName string `json:"organizationName,omitempty"` + TemplateName string `json:"templateName,omitempty"` + + LatestBuildID string `json:"latestBuildID,omitempty"` + LatestBuildStatus string `json:"latestBuildStatus,omitempty"` + AutoShutdown *metav1.Time `json:"autoShutdown,omitempty"` + LastUsedAt *metav1.Time `json:"lastUsedAt,omitempty"` } // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object @@ -19,6 +39,7 @@ type CoderWorkspaceStatus struct { // +kubebuilder:subresource:status // CoderWorkspace is the schema for Coder workspace resources. +// metadata.name is ... type CoderWorkspace struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` @@ -39,13 +60,29 @@ type CoderWorkspaceList struct { // CoderTemplateSpec defines the desired state of a CoderTemplate. type CoderTemplateSpec struct { - // Running indicates whether the template should be marked as running. - Running bool `json:"running"` + // Organization is the Coder organization name (must match the organization prefix in metadata.name). + Organization string `json:"organization"` + + // VersionID is the Coder template version UUID used on creation (required for CREATE). + VersionID string `json:"versionID"` + + DisplayName string `json:"displayName,omitempty"` + Description string `json:"description,omitempty"` + Icon string `json:"icon,omitempty"` + + // Running is a legacy flag retained temporarily for in-repo callers that still read template run-state directly. + Running bool `json:"running,omitempty"` } // CoderTemplateStatus defines the observed state of a CoderTemplate. type CoderTemplateStatus struct { - // AutoShutdown is the next planned shutdown time for workspaces created by this template. + ID string `json:"id,omitempty"` + OrganizationName string `json:"organizationName,omitempty"` + ActiveVersionID string `json:"activeVersionID,omitempty"` + Deprecated bool `json:"deprecated,omitempty"` + UpdatedAt *metav1.Time `json:"updatedAt,omitempty"` + + // AutoShutdown is a legacy timestamp retained temporarily for in-repo callers that still surface template shutdown timestamps. AutoShutdown *metav1.Time `json:"autoShutdown,omitempty"` } @@ -54,6 +91,7 @@ type CoderTemplateStatus struct { // +kubebuilder:subresource:status // CoderTemplate is the schema for Coder template resources. +// metadata.name is .. type CoderTemplate struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` diff --git a/api/aggregation/v1alpha1/zz_generated.deepcopy.go b/api/aggregation/v1alpha1/zz_generated.deepcopy.go index 8d520628..7e7de573 100644 --- a/api/aggregation/v1alpha1/zz_generated.deepcopy.go +++ b/api/aggregation/v1alpha1/zz_generated.deepcopy.go @@ -89,6 +89,10 @@ func (in *CoderTemplateSpec) DeepCopy() *CoderTemplateSpec { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *CoderTemplateStatus) DeepCopyInto(out *CoderTemplateStatus) { *out = *in + if in.UpdatedAt != nil { + in, out := &in.UpdatedAt, &out.UpdatedAt + *out = (*in).DeepCopy() + } if in.AutoShutdown != nil { in, out := &in.AutoShutdown, &out.AutoShutdown *out = (*in).DeepCopy() @@ -111,7 +115,7 @@ func (in *CoderWorkspace) DeepCopyInto(out *CoderWorkspace) { *out = *in out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) - out.Spec = in.Spec + in.Spec.DeepCopyInto(&out.Spec) in.Status.DeepCopyInto(&out.Status) return } @@ -170,6 +174,16 @@ func (in *CoderWorkspaceList) DeepCopyObject() runtime.Object { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *CoderWorkspaceSpec) DeepCopyInto(out *CoderWorkspaceSpec) { *out = *in + if in.TTLMillis != nil { + in, out := &in.TTLMillis, &out.TTLMillis + *out = new(int64) + **out = **in + } + if in.AutostartSchedule != nil { + in, out := &in.AutostartSchedule, &out.AutostartSchedule + *out = new(string) + **out = **in + } return } @@ -190,6 +204,10 @@ func (in *CoderWorkspaceStatus) DeepCopyInto(out *CoderWorkspaceStatus) { in, out := &in.AutoShutdown, &out.AutoShutdown *out = (*in).DeepCopy() } + if in.LastUsedAt != nil { + in, out := &in.LastUsedAt, &out.LastUsedAt + *out = (*in).DeepCopy() + } return } diff --git a/app_dispatch.go b/app_dispatch.go index 5b471801..6a500aa6 100644 --- a/app_dispatch.go +++ b/app_dispatch.go @@ -1,8 +1,12 @@ package main import ( + "context" "flag" "fmt" + "net/url" + "strings" + "time" ctrl "sigs.k8s.io/controller-runtime" @@ -15,24 +19,79 @@ const supportedAppModes = "controller, aggregated-apiserver, mcp-http" var ( runControllerApp = controllerapp.Run - runAggregatedAPIServerApp = apiserverapp.Run - runMCPHTTPApp = mcpapp.RunHTTP - setupSignalHandler = ctrl.SetupSignalHandler + runAggregatedAPIServerApp = func(ctx context.Context, opts apiserverapp.Options) error { + return apiserverapp.RunWithOptions(ctx, opts) + } + runMCPHTTPApp = mcpapp.RunHTTP + setupSignalHandler = ctrl.SetupSignalHandler ) func run(args []string) error { fs := flag.NewFlagSet("coder-k8s", flag.ContinueOnError) - var appMode string + var ( + appMode string + coderURL string + coderSessionToken string + coderNamespace string + coderRequestTimeout time.Duration + ) fs.StringVar(&appMode, "app", "", "Application mode (controller, aggregated-apiserver, mcp-http)") + fs.StringVar( + &coderSessionToken, + "coder-session-token", + "", + "Admin session token for the backing Coder deployment", + ) + fs.StringVar( + &coderURL, + "coder-url", + "", + "Coder deployment URL (fallback when CoderControlPlane status URL is unavailable)", + ) + fs.StringVar( + &coderNamespace, + "coder-namespace", + "", + "Restrict the aggregated API server to serve only this Kubernetes namespace", + ) + fs.DurationVar( + &coderRequestTimeout, + "coder-request-timeout", + 30*time.Second, + "Timeout for Coder SDK API requests", + ) if err := fs.Parse(args); err != nil { return err } + if coderURL != "" { + parsedCoderURL, err := url.Parse(coderURL) + if err != nil { + return fmt.Errorf("assertion failed: invalid --coder-url %q: %w", coderURL, err) + } + if parsedCoderURL.Scheme == "" || parsedCoderURL.Host == "" { + return fmt.Errorf( + "assertion failed: invalid --coder-url %q: must include scheme and host (for example, https://coder.example.com)", + coderURL, + ) + } + scheme := strings.ToLower(parsedCoderURL.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("assertion failed: invalid --coder-url %q: scheme must be http or https", coderURL) + } + } + switch appMode { case "controller": return runControllerApp(setupSignalHandler()) case "aggregated-apiserver": - return runAggregatedAPIServerApp(setupSignalHandler()) + opts := apiserverapp.Options{ + CoderURL: coderURL, + CoderSessionToken: coderSessionToken, + CoderNamespace: coderNamespace, + CoderRequestTimeout: coderRequestTimeout, + } + return runAggregatedAPIServerApp(setupSignalHandler(), opts) case "mcp-http": return runMCPHTTPApp(setupSignalHandler()) case "": diff --git a/config/crd/bases/aggregation.coder.com_codertemplates.yaml b/config/crd/bases/aggregation.coder.com_codertemplates.yaml index f016c7d9..7d587483 100644 --- a/config/crd/bases/aggregation.coder.com_codertemplates.yaml +++ b/config/crd/bases/aggregation.coder.com_codertemplates.yaml @@ -17,7 +17,9 @@ spec: - name: v1alpha1 schema: openAPIV3Schema: - description: CoderTemplate is the schema for Coder template resources. + description: |- + CoderTemplate is the schema for Coder template resources. + metadata.name is .. properties: apiVersion: description: |- @@ -39,19 +41,45 @@ spec: spec: description: CoderTemplateSpec defines the desired state of a CoderTemplate. properties: + description: + type: string + displayName: + type: string + icon: + type: string + organization: + description: Organization is the Coder organization name (must match + the organization prefix in metadata.name). + type: string running: - description: Running indicates whether the template should be marked - as running. + description: Running is a legacy flag retained temporarily for in-repo + callers that still read template run-state directly. type: boolean + versionID: + description: VersionID is the Coder template version UUID used on + creation (required for CREATE). + type: string required: - - running + - organization + - versionID type: object status: description: CoderTemplateStatus defines the observed state of a CoderTemplate. properties: + activeVersionID: + type: string autoShutdown: - description: AutoShutdown is the next planned shutdown time for workspaces - created by this template. + description: AutoShutdown is a legacy timestamp retained temporarily + for in-repo callers that still surface template shutdown timestamps. + format: date-time + type: string + deprecated: + type: boolean + id: + type: string + organizationName: + type: string + updatedAt: format: date-time type: string type: object diff --git a/config/crd/bases/aggregation.coder.com_coderworkspaces.yaml b/config/crd/bases/aggregation.coder.com_coderworkspaces.yaml index ece01c17..50abcb09 100644 --- a/config/crd/bases/aggregation.coder.com_coderworkspaces.yaml +++ b/config/crd/bases/aggregation.coder.com_coderworkspaces.yaml @@ -17,7 +17,9 @@ spec: - name: v1alpha1 schema: openAPIV3Schema: - description: CoderWorkspace is the schema for Coder workspace resources. + description: |- + CoderWorkspace is the schema for Coder workspace resources. + metadata.name is ... properties: apiVersion: description: |- @@ -39,9 +41,25 @@ spec: spec: description: CoderWorkspaceSpec defines the desired state of a CoderWorkspace. properties: + autostartSchedule: + type: string + organization: + description: Organization is the Coder organization name. + type: string running: - description: Running indicates whether the workspace should be running. + description: Running drives start/stop via CreateWorkspaceBuild. type: boolean + templateName: + description: TemplateName resolves via TemplateByName(organization, + templateName). + type: string + templateVersionID: + description: TemplateVersionID optionally pins to a specific template + version. + type: string + ttlMillis: + format: int64 + type: integer required: - running type: object @@ -49,10 +67,23 @@ spec: description: CoderWorkspaceStatus defines the observed state of a CoderWorkspace. properties: autoShutdown: - description: AutoShutdown is the next planned shutdown time for the - workspace. format: date-time type: string + id: + type: string + lastUsedAt: + format: date-time + type: string + latestBuildID: + type: string + latestBuildStatus: + type: string + organizationName: + type: string + ownerName: + type: string + templateName: + type: string type: object type: object served: true diff --git a/docs/reference/api/codertemplate.md b/docs/reference/api/codertemplate.md index edb32cb4..a4b3f200 100644 --- a/docs/reference/api/codertemplate.md +++ b/docs/reference/api/codertemplate.md @@ -13,17 +13,25 @@ | Field | Type | Description | | --- | --- | --- | -| `spec.running` | `bool` | Running indicates whether the template should be marked as running. | +| `spec.organization` | `string` | Organization is the Coder organization name (must match the organization prefix in metadata.name). | +| `spec.versionID` | `string` | VersionID is the Coder template version UUID used on creation (required for CREATE). | +| `spec.displayName` | `string` | | +| `spec.description` | `string` | | +| `spec.icon` | `string` | | +| `spec.running` | `bool` | Running is a legacy flag retained temporarily for in-repo callers that still read template run-state directly. | ## Status | Field | Type | Description | | --- | --- | --- | -| `status.autoShutdown` | `metav1.Time` | AutoShutdown is the next planned shutdown time for workspaces created by this template. | +| `status.id` | `string` | | +| `status.organizationName` | `string` | | +| `status.activeVersionID` | `string` | | +| `status.deprecated` | `bool` | | +| `status.updatedAt` | `metav1.Time` | | +| `status.autoShutdown` | `metav1.Time` | AutoShutdown is a legacy timestamp retained temporarily for in-repo callers that still surface template shutdown timestamps. | ## Source - Go type: `api/aggregation/v1alpha1/types.go` -- Storage implementation: `internal/aggregated/storage/template.go` - - APIService registration manifest: `deploy/apiserver-apiservice.yaml` diff --git a/docs/reference/api/coderworkspace.md b/docs/reference/api/coderworkspace.md index e92798f1..6e7a10db 100644 --- a/docs/reference/api/coderworkspace.md +++ b/docs/reference/api/coderworkspace.md @@ -13,17 +13,27 @@ | Field | Type | Description | | --- | --- | --- | -| `spec.running` | `bool` | Running indicates whether the workspace should be running. | +| `spec.organization` | `string` | Organization is the Coder organization name. | +| `spec.templateName` | `string` | TemplateName resolves via TemplateByName(organization, templateName). | +| `spec.templateVersionID` | `string` | TemplateVersionID optionally pins to a specific template version. | +| `spec.running` | `bool` | Running drives start/stop via CreateWorkspaceBuild. | +| `spec.ttlMillis` | `int64` | | +| `spec.autostartSchedule` | `string` | | ## Status | Field | Type | Description | | --- | --- | --- | -| `status.autoShutdown` | `metav1.Time` | AutoShutdown is the next planned shutdown time for the workspace. | +| `status.id` | `string` | | +| `status.ownerName` | `string` | | +| `status.organizationName` | `string` | | +| `status.templateName` | `string` | | +| `status.latestBuildID` | `string` | | +| `status.latestBuildStatus` | `string` | | +| `status.autoShutdown` | `metav1.Time` | | +| `status.lastUsedAt` | `metav1.Time` | | ## Source - Go type: `api/aggregation/v1alpha1/types.go` -- Storage implementation: `internal/aggregated/storage/workspace.go` - - APIService registration manifest: `deploy/apiserver-apiservice.yaml` diff --git a/internal/aggregated/coder/config.go b/internal/aggregated/coder/config.go new file mode 100644 index 00000000..2c25b7c5 --- /dev/null +++ b/internal/aggregated/coder/config.go @@ -0,0 +1,54 @@ +// Package coder provides shared Coder backend helpers for the aggregated API server. +package coder + +import ( + "fmt" + "net/url" + "time" + + "github.com/coder/coder/v2/codersdk" +) + +const defaultRequestTimeout = 30 * time.Second + +// Config describes how to construct a Coder SDK client. +type Config struct { + CoderURL *url.URL + SessionToken string + RequestTimeout time.Duration +} + +// NewSDKClient creates a configured Coder SDK client from cfg. +func NewSDKClient(cfg Config) (*codersdk.Client, error) { + if cfg.CoderURL == nil { + return nil, fmt.Errorf("assertion failed: coder URL must not be nil") + } + if cfg.SessionToken == "" { + return nil, fmt.Errorf("assertion failed: session token must not be empty") + } + + requestTimeout := cfg.RequestTimeout + switch { + case requestTimeout < 0: + return nil, fmt.Errorf("assertion failed: request timeout must not be negative") + case requestTimeout == 0: + requestTimeout = defaultRequestTimeout + } + + coderURL := *cfg.CoderURL + client := codersdk.New(&coderURL) + if client == nil { + return nil, fmt.Errorf("assertion failed: coder SDK client is nil after successful construction") + } + if client.HTTPClient == nil { + return nil, fmt.Errorf("assertion failed: coder SDK HTTP client is nil after successful construction") + } + + client.HTTPClient.Timeout = requestTimeout + client.SetSessionToken(cfg.SessionToken) + if client.SessionToken() == "" { + return nil, fmt.Errorf("assertion failed: coder SDK session token is empty after successful configuration") + } + + return client, nil +} diff --git a/internal/aggregated/coder/config_test.go b/internal/aggregated/coder/config_test.go new file mode 100644 index 00000000..18cc9bd5 --- /dev/null +++ b/internal/aggregated/coder/config_test.go @@ -0,0 +1,110 @@ +package coder + +import ( + "net/url" + "strings" + "testing" + "time" +) + +func TestNewSDKClient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config Config + wantErrContains string + wantTimeout time.Duration + }{ + { + name: "defaults timeout when omitted", + config: Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + }, + wantTimeout: defaultRequestTimeout, + }, + { + name: "uses explicit timeout", + config: Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + RequestTimeout: 45 * time.Second, + }, + wantTimeout: 45 * time.Second, + }, + { + name: "rejects nil coder URL", + config: Config{ + SessionToken: "session-token", + }, + wantErrContains: "assertion failed: coder URL must not be nil", + }, + { + name: "rejects empty session token", + config: Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + }, + wantErrContains: "assertion failed: session token must not be empty", + }, + { + name: "rejects negative timeout", + config: Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + RequestTimeout: -1 * time.Second, + }, + wantErrContains: "assertion failed: request timeout must not be negative", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + client, err := NewSDKClient(testCase.config) + if testCase.wantErrContains != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", testCase.wantErrContains) + } + if !strings.Contains(err.Error(), testCase.wantErrContains) { + t.Fatalf("expected error to contain %q, got %q", testCase.wantErrContains, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } + if client.HTTPClient == nil { + t.Fatal("expected non-nil HTTP client") + } + if got, want := client.HTTPClient.Timeout, testCase.wantTimeout; got != want { + t.Fatalf("expected timeout %s, got %s", want, got) + } + if got, want := client.SessionToken(), testCase.config.SessionToken; got != want { + t.Fatalf("expected session token %q, got %q", want, got) + } + if got, want := client.URL.String(), testCase.config.CoderURL.String(); got != want { + t.Fatalf("expected URL %q, got %q", want, got) + } + }) + } +} + +func mustParseURL(t *testing.T, rawURL string) *url.URL { + t.Helper() + + parsedURL, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("parse URL %q: %v", rawURL, err) + } + if parsedURL == nil { + t.Fatalf("parse URL %q returned nil URL", rawURL) + } + + return parsedURL +} diff --git a/internal/aggregated/coder/errors.go b/internal/aggregated/coder/errors.go new file mode 100644 index 00000000..2e92ede1 --- /dev/null +++ b/internal/aggregated/coder/errors.go @@ -0,0 +1,83 @@ +package coder + +import ( + "errors" + "fmt" + "net/http" + "strings" + + "github.com/coder/coder/v2/codersdk" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +// MapCoderError converts Coder SDK errors to Kubernetes API errors. +func MapCoderError(err error, resource schema.GroupResource, name string) error { + if err == nil { + return fmt.Errorf("assertion failed: error must not be nil") + } + if resource.Empty() { + return fmt.Errorf("assertion failed: resource must not be empty") + } + if name == "" { + return fmt.Errorf("assertion failed: resource name must not be empty") + } + + var coderErr *codersdk.Error + if !errors.As(err, &coderErr) { + return apierrors.NewInternalError(err) + } + + statusCode := coderErr.StatusCode() + message := coderErrorMessage(coderErr, err) + + switch statusCode { + case http.StatusNotFound: + return apierrors.NewNotFound(resource, name) + case http.StatusForbidden: + return apierrors.NewForbidden(resource, name, err) + case http.StatusConflict: + if isAlreadyExistsConflict(coderErr) { + return apierrors.NewAlreadyExists(resource, name) + } + return apierrors.NewConflict(resource, name, err) + case http.StatusBadRequest, http.StatusUnprocessableEntity: + return apierrors.NewBadRequest(message) + case http.StatusUnauthorized: + return apierrors.NewUnauthorized(message) + case http.StatusTooManyRequests: + return apierrors.NewTooManyRequests(message, 0) + default: + if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError { + return apierrors.NewBadRequest(message) + } + + return apierrors.NewInternalError(err) + } +} + +func coderErrorMessage(coderErr *codersdk.Error, fallback error) string { + if coderErr == nil { + panic("assertion failed: coder error must not be nil") + } + if fallback == nil { + panic("assertion failed: fallback error must not be nil") + } + + message := strings.TrimSpace(coderErr.Message) + if message != "" { + return message + } + + return fallback.Error() +} + +func isAlreadyExistsConflict(err *codersdk.Error) bool { + if err == nil { + panic("assertion failed: coder error must not be nil") + } + + message := strings.ToLower(err.Message) + + return strings.Contains(message, "already exists") +} diff --git a/internal/aggregated/coder/errors_test.go b/internal/aggregated/coder/errors_test.go new file mode 100644 index 00000000..461f85f1 --- /dev/null +++ b/internal/aggregated/coder/errors_test.go @@ -0,0 +1,215 @@ +package coder + +import ( + "errors" + "net/http" + "strings" + "testing" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder/v2/codersdk" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +func TestMapCoderError(t *testing.T) { + t.Parallel() + + resource := aggregationv1alpha1.Resource("coderworkspaces") + name := "acme.alice.dev" + + tests := []struct { + name string + err error + assertMapping func(t *testing.T, err error) + }{ + { + name: "maps not found", + err: codersdk.NewTestError(http.StatusNotFound, http.MethodGet, "https://coder.example.com"), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsNotFound(err) { + t.Fatalf("expected NotFound, got %v", err) + } + }, + }, + { + name: "maps forbidden", + err: codersdk.NewTestError(http.StatusForbidden, http.MethodGet, "https://coder.example.com"), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsForbidden(err) { + t.Fatalf("expected Forbidden, got %v", err) + } + }, + }, + { + name: "maps bad request", + err: withCoderMessage( + codersdk.NewTestError(http.StatusBadRequest, http.MethodGet, "https://coder.example.com"), + "bad workspace request", + ), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest, got %v", err) + } + }, + }, + { + name: "maps unprocessable entity to bad request", + err: withCoderMessage( + codersdk.NewTestError(http.StatusUnprocessableEntity, http.MethodGet, "https://coder.example.com"), + "invalid workspace transition", + ), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest, got %v", err) + } + }, + }, + { + name: "maps unauthorized", + err: withCoderMessage( + codersdk.NewTestError(http.StatusUnauthorized, http.MethodGet, "https://coder.example.com"), + "invalid session token", + ), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsUnauthorized(err) { + t.Fatalf("expected Unauthorized, got %v", err) + } + }, + }, + { + name: "maps too many requests", + err: withCoderMessage( + codersdk.NewTestError(http.StatusTooManyRequests, http.MethodGet, "https://coder.example.com"), + "rate limited", + ), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsTooManyRequests(err) { + t.Fatalf("expected TooManyRequests, got %v", err) + } + }, + }, + { + name: "maps create conflict to already exists", + err: withCoderMessage( + codersdk.NewTestError(http.StatusConflict, http.MethodPost, "https://coder.example.com"), + "workspace already exists", + ), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsAlreadyExists(err) { + t.Fatalf("expected AlreadyExists, got %v", err) + } + }, + }, + { + name: "maps update conflict to conflict", + err: withCoderMessage( + codersdk.NewTestError(http.StatusConflict, http.MethodPatch, "https://coder.example.com"), + "resource version mismatch", + ), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsConflict(err) { + t.Fatalf("expected Conflict, got %v", err) + } + }, + }, + { + name: "maps coder internal errors", + err: codersdk.NewTestError(http.StatusInternalServerError, http.MethodGet, "https://coder.example.com"), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsInternalError(err) { + t.Fatalf("expected InternalError, got %v", err) + } + }, + }, + { + name: "maps generic errors to internal", + err: errors.New("boom"), + assertMapping: func(t *testing.T, err error) { + t.Helper() + if !apierrors.IsInternalError(err) { + t.Fatalf("expected InternalError, got %v", err) + } + }, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + mappedErr := MapCoderError(testCase.err, resource, name) + testCase.assertMapping(t, mappedErr) + }) + } +} + +func TestMapCoderErrorAssertions(t *testing.T) { + t.Parallel() + + resource := aggregationv1alpha1.Resource("coderworkspaces") + coderErr := codersdk.NewTestError(http.StatusNotFound, http.MethodGet, "https://coder.example.com") + + tests := []struct { + name string + err error + resource schema.GroupResource + resourceName string + wantErrContains string + }{ + { + name: "rejects nil error", + err: nil, + resource: resource, + resourceName: "acme.alice.dev", + wantErrContains: "assertion failed: error must not be nil", + }, + { + name: "rejects empty resource", + err: coderErr, + resource: schema.GroupResource{}, + resourceName: "acme.alice.dev", + wantErrContains: "assertion failed: resource must not be empty", + }, + { + name: "rejects empty resource name", + err: coderErr, + resource: resource, + resourceName: "", + wantErrContains: "assertion failed: resource name must not be empty", + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + err := MapCoderError(testCase.err, testCase.resource, testCase.resourceName) + if err == nil { + t.Fatalf("expected error containing %q, got nil", testCase.wantErrContains) + } + if !strings.Contains(err.Error(), testCase.wantErrContains) { + t.Fatalf("expected error containing %q, got %q", testCase.wantErrContains, err.Error()) + } + }) + } +} + +func withCoderMessage(err *codersdk.Error, message string) *codersdk.Error { + if err == nil { + panic("assertion failed: coder error must not be nil") + } + + err.Message = message + + return err +} diff --git a/internal/aggregated/coder/names.go b/internal/aggregated/coder/names.go new file mode 100644 index 00000000..e3cd05b7 --- /dev/null +++ b/internal/aggregated/coder/names.go @@ -0,0 +1,95 @@ +package coder + +import ( + "fmt" + "strings" +) + +const nameSeparator = "." + +// ParseTemplateName splits "." into organization and template names. +func ParseTemplateName(name string) (org, template string, err error) { + segments, err := parseNameSegments(name, 2, "template") + if err != nil { + return "", "", err + } + + return segments[0], segments[1], nil +} + +// ParseWorkspaceName splits ".." into organization, user, and workspace names. +func ParseWorkspaceName(name string) (org, user, workspace string, err error) { + segments, err := parseNameSegments(name, 3, "workspace") + if err != nil { + return "", "", "", err + } + + return segments[0], segments[1], segments[2], nil +} + +// BuildTemplateName constructs ".". +func BuildTemplateName(org, template string) string { + assertNameSegment("organization", org) + assertNameSegment("template", template) + + return org + nameSeparator + template +} + +// BuildWorkspaceName constructs "..". +func BuildWorkspaceName(org, user, workspace string) string { + assertNameSegment("organization", org) + assertNameSegment("user", user) + assertNameSegment("workspace", workspace) + + return org + nameSeparator + user + nameSeparator + workspace +} + +func parseNameSegments(name string, expectedSegments int, objectType string) ([]string, error) { + if name == "" { + return nil, fmt.Errorf("invalid %s name: name must not be empty", objectType) + } + + expectedSeparatorCount := expectedSegments - 1 + if strings.Count(name, nameSeparator) != expectedSeparatorCount { + return nil, fmt.Errorf( + "invalid %s name %q: expected %d separators (%q)", + objectType, + name, + expectedSeparatorCount, + nameSeparator, + ) + } + + segments := strings.Split(name, nameSeparator) + if len(segments) != expectedSegments { + return nil, fmt.Errorf( + "assertion failed: parsed %s name %q into %d segments; expected %d", + objectType, + name, + len(segments), + expectedSegments, + ) + } + + for segmentIndex, segment := range segments { + if segment == "" { + return nil, fmt.Errorf( + "invalid %s name %q: segment %d must not be empty", + objectType, + name, + segmentIndex, + ) + } + } + + return segments, nil +} + +func assertNameSegment(segmentType, value string) { + if value == "" { + panic(fmt.Sprintf("assertion failed: %s must not be empty", segmentType)) + } + if strings.Contains(value, nameSeparator) { + panic(fmt.Sprintf("assertion failed: %s must not contain %q", segmentType, nameSeparator)) + } +} diff --git a/internal/aggregated/coder/names_test.go b/internal/aggregated/coder/names_test.go new file mode 100644 index 00000000..eb349a7b --- /dev/null +++ b/internal/aggregated/coder/names_test.go @@ -0,0 +1,174 @@ +package coder + +import ( + "strings" + "testing" +) + +func TestParseTemplateName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantOrg string + wantTmpl string + wantError bool + }{ + {name: "valid", input: "acme.starter", wantOrg: "acme", wantTmpl: "starter"}, + {name: "empty input", input: "", wantError: true}, + {name: "missing separator", input: "acme", wantError: true}, + {name: "too many separators", input: "acme.team.starter", wantError: true}, + {name: "empty organization", input: ".starter", wantError: true}, + {name: "empty template", input: "acme.", wantError: true}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + org, template, err := ParseTemplateName(testCase.input) + if testCase.wantError { + if err == nil { + t.Fatalf("expected error for input %q", testCase.input) + } + return + } + if err != nil { + t.Fatalf("unexpected error for input %q: %v", testCase.input, err) + } + if org != testCase.wantOrg { + t.Fatalf("expected organization %q, got %q", testCase.wantOrg, org) + } + if template != testCase.wantTmpl { + t.Fatalf("expected template %q, got %q", testCase.wantTmpl, template) + } + }) + } +} + +func TestParseWorkspaceName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantOrg string + wantUser string + wantWorkspace string + wantError bool + }{ + {name: "valid", input: "acme.alice.dev", wantOrg: "acme", wantUser: "alice", wantWorkspace: "dev"}, + {name: "empty input", input: "", wantError: true}, + {name: "missing separator", input: "acme", wantError: true}, + {name: "too few separators", input: "acme.alice", wantError: true}, + {name: "too many separators", input: "acme.alice.team.dev", wantError: true}, + {name: "empty organization", input: ".alice.dev", wantError: true}, + {name: "empty user", input: "acme..dev", wantError: true}, + {name: "empty workspace", input: "acme.alice.", wantError: true}, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + org, user, workspace, err := ParseWorkspaceName(testCase.input) + if testCase.wantError { + if err == nil { + t.Fatalf("expected error for input %q", testCase.input) + } + return + } + if err != nil { + t.Fatalf("unexpected error for input %q: %v", testCase.input, err) + } + if org != testCase.wantOrg { + t.Fatalf("expected organization %q, got %q", testCase.wantOrg, org) + } + if user != testCase.wantUser { + t.Fatalf("expected user %q, got %q", testCase.wantUser, user) + } + if workspace != testCase.wantWorkspace { + t.Fatalf("expected workspace %q, got %q", testCase.wantWorkspace, workspace) + } + }) + } +} + +func TestBuildTemplateName(t *testing.T) { + t.Parallel() + + if got, want := BuildTemplateName("acme", "starter"), "acme.starter"; got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} + +func TestBuildWorkspaceName(t *testing.T) { + t.Parallel() + + if got, want := BuildWorkspaceName("acme", "alice", "dev"), "acme.alice.dev"; got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} + +func TestBuildNamePanicsForInvalidSegments(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + fn func() + }{ + { + name: "empty org in template", + fn: func() { + _ = BuildTemplateName("", "starter") + }, + }, + { + name: "dot in template segment", + fn: func() { + _ = BuildTemplateName("acme", "starter.v2") + }, + }, + { + name: "empty workspace segment", + fn: func() { + _ = BuildWorkspaceName("acme", "alice", "") + }, + }, + { + name: "dot in user segment", + fn: func() { + _ = BuildWorkspaceName("acme", "alice.dev", "workspace") + }, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + expectAssertionPanic(t, testCase.fn) + }) + } +} + +func expectAssertionPanic(t *testing.T, fn func()) { + t.Helper() + + defer func() { + recovered := recover() + if recovered == nil { + t.Fatal("expected panic, got nil") + } + + message, ok := recovered.(string) + if !ok { + t.Fatalf("expected panic string, got %T (%v)", recovered, recovered) + } + if !strings.HasPrefix(message, "assertion failed:") { + t.Fatalf("expected assertion panic, got %q", message) + } + }() + + fn() +} diff --git a/internal/aggregated/coder/provider.go b/internal/aggregated/coder/provider.go new file mode 100644 index 00000000..db5c6415 --- /dev/null +++ b/internal/aggregated/coder/provider.go @@ -0,0 +1,73 @@ +package coder + +import ( + "context" + "fmt" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + + "github.com/coder/coder/v2/codersdk" +) + +// ClientProvider resolves a Coder SDK client for a Kubernetes request namespace. +type ClientProvider interface { + ClientForNamespace(ctx context.Context, namespace string) (*codersdk.Client, error) +} + +// StaticClientProvider returns one static client, optionally restricted to one namespace. +type StaticClientProvider struct { + Client *codersdk.Client + Namespace string // If non-empty, only this namespace is allowed. +} + +var _ ClientProvider = (*StaticClientProvider)(nil) + +// ClientForNamespace returns the static client. +func (p *StaticClientProvider) ClientForNamespace(ctx context.Context, namespace string) (*codersdk.Client, error) { + if p == nil { + return nil, fmt.Errorf("assertion failed: static client provider must not be nil") + } + if ctx == nil { + return nil, fmt.Errorf("assertion failed: context must not be nil") + } + if p.Client == nil { + return nil, fmt.Errorf("assertion failed: static client provider client must not be nil") + } + if p.Namespace == "" { + return nil, apierrors.NewServiceUnavailable( + "static coder client provider is not namespace-pinned; configure --coder-namespace", + ) + } + if namespace == "" { + namespace = p.Namespace + } + if namespace != p.Namespace { + return nil, apierrors.NewBadRequest( + fmt.Sprintf( + "namespace %q is not served by this aggregated API server (configured for %q)", + namespace, + p.Namespace, + ), + ) + } + + return p.Client, nil +} + +// NewStaticClientProvider creates a StaticClientProvider from cfg and optional namespace restriction. +func NewStaticClientProvider(cfg Config, namespace string) (*StaticClientProvider, error) { + client, err := NewSDKClient(cfg) + if err != nil { + return nil, fmt.Errorf("new SDK client: %w", err) + } + + provider := &StaticClientProvider{ + Client: client, + Namespace: namespace, + } + if provider.Client == nil { + return nil, fmt.Errorf("assertion failed: static client provider client is nil after successful construction") + } + + return provider, nil +} diff --git a/internal/aggregated/coder/provider_test.go b/internal/aggregated/coder/provider_test.go new file mode 100644 index 00000000..f81246e9 --- /dev/null +++ b/internal/aggregated/coder/provider_test.go @@ -0,0 +1,215 @@ +package coder + +import ( + "context" + "strings" + "testing" + + apierrors "k8s.io/apimachinery/pkg/api/errors" +) + +func TestStaticClientProviderClientForNamespace(t *testing.T) { + t.Parallel() + + client, err := NewSDKClient(Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + }) + if err != nil { + t.Fatalf("create SDK client: %v", err) + } + + provider := &StaticClientProvider{Client: client, Namespace: "default"} + resolvedClient, err := provider.ClientForNamespace(context.Background(), "default") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if resolvedClient != client { + t.Fatalf("expected provider to return static client %p, got %p", client, resolvedClient) + } +} + +func TestStaticClientProviderClientForNamespaceAssertions(t *testing.T) { + t.Parallel() + + validClient, err := NewSDKClient(Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + }) + if err != nil { + t.Fatalf("create SDK client: %v", err) + } + + tests := []struct { + name string + provider *StaticClientProvider + ctx context.Context + namespace string + wantErrContains string + }{ + { + name: "rejects nil provider", + provider: nil, + ctx: context.Background(), + namespace: "default", + wantErrContains: "assertion failed: static client provider must not be nil", + }, + { + name: "rejects nil context", + provider: &StaticClientProvider{Client: validClient}, + ctx: nil, + namespace: "default", + wantErrContains: "assertion failed: context must not be nil", + }, + { + name: "rejects nil client", + provider: &StaticClientProvider{}, + ctx: context.Background(), + namespace: "default", + wantErrContains: "assertion failed: static client provider client must not be nil", + }, + { + name: "rejects unpinned provider", + provider: &StaticClientProvider{Client: validClient}, + ctx: context.Background(), + namespace: "default", + wantErrContains: "static coder client provider is not namespace-pinned; configure --coder-namespace", + }, + } + + for _, testCase := range tests { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + _, err := testCase.provider.ClientForNamespace(testCase.ctx, testCase.namespace) + if err == nil { + t.Fatalf("expected error containing %q, got nil", testCase.wantErrContains) + } + if !strings.Contains(err.Error(), testCase.wantErrContains) { + t.Fatalf("expected error containing %q, got %q", testCase.wantErrContains, err.Error()) + } + }) + } +} + +func TestStaticClientProviderClientForNamespaceNamespaceRestriction(t *testing.T) { + t.Parallel() + + client, err := NewSDKClient(Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + }) + if err != nil { + t.Fatalf("create SDK client: %v", err) + } + + provider := &StaticClientProvider{ + Client: client, + Namespace: "control-plane", + } + + resolvedClient, err := provider.ClientForNamespace(context.Background(), "control-plane") + if err != nil { + t.Fatalf("expected no error for matching namespace, got %v", err) + } + if resolvedClient != client { + t.Fatalf("expected provider to return static client %p, got %p", client, resolvedClient) + } + + _, err = provider.ClientForNamespace(context.Background(), "default") + if err == nil { + t.Fatal("expected namespace mismatch to fail") + } + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for namespace mismatch, got %v", err) + } + wantErrContains := "namespace \"default\" is not served by this aggregated API server (configured for \"control-plane\")" + if !strings.Contains(err.Error(), wantErrContains) { + t.Fatalf("expected error containing %q, got %q", wantErrContains, err.Error()) + } +} + +func TestStaticClientProviderClientForNamespaceAllowsClusterScopedListNamespace(t *testing.T) { + t.Parallel() + + client, err := NewSDKClient(Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + }) + if err != nil { + t.Fatalf("create SDK client: %v", err) + } + + provider := &StaticClientProvider{ + Client: client, + Namespace: "control-plane", + } + + resolvedClient, err := provider.ClientForNamespace(context.Background(), "") + if err != nil { + t.Fatalf("expected no error for empty namespace when provider is pinned, got %v", err) + } + if resolvedClient != client { + t.Fatalf("expected provider to return static client %p, got %p", client, resolvedClient) + } +} + +func TestNewStaticClientProvider(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg Config + namespace string + wantErrContains string + }{ + { + name: "success", + cfg: Config{ + CoderURL: mustParseURL(t, "https://coder.example.com"), + SessionToken: "session-token", + }, + namespace: "control-plane", + }, + { + name: "surfaces SDK config assertion", + cfg: Config{ + SessionToken: "session-token", + }, + namespace: "control-plane", + wantErrContains: "new SDK client: assertion failed: coder URL must not be nil", + }, + } + + for _, testCase := range tests { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + provider, err := NewStaticClientProvider(testCase.cfg, testCase.namespace) + if testCase.wantErrContains != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", testCase.wantErrContains) + } + if !strings.Contains(err.Error(), testCase.wantErrContains) { + t.Fatalf("expected error containing %q, got %q", testCase.wantErrContains, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if provider == nil { + t.Fatal("expected non-nil provider") + } + if provider.Client == nil { + t.Fatal("expected non-nil provider client") + } + if provider.Namespace != testCase.namespace { + t.Fatalf("expected provider namespace %q, got %q", testCase.namespace, provider.Namespace) + } + }) + } +} diff --git a/internal/aggregated/convert/template.go b/internal/aggregated/convert/template.go new file mode 100644 index 00000000..be93f4e4 --- /dev/null +++ b/internal/aggregated/convert/template.go @@ -0,0 +1,74 @@ +// Package convert maps codersdk models to aggregated API resources and request payloads. +package convert + +import ( + "fmt" + "strconv" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder-k8s/internal/aggregated/coder" + "github.com/coder/coder/v2/codersdk" + "github.com/google/uuid" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +// TemplateToK8s converts a codersdk.Template to an aggregated API CoderTemplate. +func TemplateToK8s(namespace string, t codersdk.Template) *aggregationv1alpha1.CoderTemplate { + if namespace == "" { + panic("assertion failed: namespace must not be empty") + } + + updatedAt := metav1.NewTime(t.UpdatedAt) + + return &aggregationv1alpha1.CoderTemplate{ + TypeMeta: metav1.TypeMeta{ + Kind: "CoderTemplate", + APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: coder.BuildTemplateName(t.OrganizationName, t.Name), + Namespace: namespace, + UID: types.UID(t.ID.String()), + ResourceVersion: strconv.FormatInt(t.UpdatedAt.UnixNano(), 10), + CreationTimestamp: metav1.NewTime(t.CreatedAt), + }, + Spec: aggregationv1alpha1.CoderTemplateSpec{ + Organization: t.OrganizationName, + VersionID: t.ActiveVersionID.String(), + DisplayName: t.DisplayName, + Description: t.Description, + Icon: t.Icon, + }, + Status: aggregationv1alpha1.CoderTemplateStatus{ + ID: t.ID.String(), + OrganizationName: t.OrganizationName, + ActiveVersionID: t.ActiveVersionID.String(), + Deprecated: t.Deprecated, + UpdatedAt: &updatedAt, + }, + } +} + +// TemplateCreateRequestFromK8s builds a codersdk.CreateTemplateRequest from a K8s CoderTemplate. +func TemplateCreateRequestFromK8s(obj *aggregationv1alpha1.CoderTemplate, templateName string) (codersdk.CreateTemplateRequest, error) { + if obj == nil { + return codersdk.CreateTemplateRequest{}, fmt.Errorf("assertion failed: template object must not be nil") + } + if templateName == "" { + return codersdk.CreateTemplateRequest{}, fmt.Errorf("assertion failed: template name must not be empty") + } + + versionID, err := uuid.Parse(obj.Spec.VersionID) + if err != nil { + return codersdk.CreateTemplateRequest{}, fmt.Errorf("parse template spec.versionID %q: %w", obj.Spec.VersionID, err) + } + + return codersdk.CreateTemplateRequest{ + Name: templateName, + VersionID: versionID, + DisplayName: obj.Spec.DisplayName, + Description: obj.Spec.Description, + Icon: obj.Spec.Icon, + }, nil +} diff --git a/internal/aggregated/convert/template_test.go b/internal/aggregated/convert/template_test.go new file mode 100644 index 00000000..6547b972 --- /dev/null +++ b/internal/aggregated/convert/template_test.go @@ -0,0 +1,136 @@ +package convert + +import ( + "strconv" + "strings" + "testing" + "time" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder/v2/codersdk" + "github.com/google/uuid" +) + +func TestTemplateToK8s(t *testing.T) { + t.Parallel() + + templateID := uuid.New() + activeVersionID := uuid.New() + createdAt := time.Date(2025, time.January, 2, 3, 4, 5, 0, time.UTC) + updatedAt := createdAt.Add(2 * time.Hour) + + template := codersdk.Template{ + ID: templateID, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + OrganizationName: "acme", + Name: "starter-template", + DisplayName: "Starter Template", + Description: "Base development template", + Icon: "/icons/starter.png", + ActiveVersionID: activeVersionID, + Deprecated: true, + } + + converted := TemplateToK8s("control-plane", template) + if converted == nil { + t.Fatal("expected non-nil converted template") + } + if converted.Name != "acme.starter-template" { + t.Fatalf("expected name acme.starter-template, got %q", converted.Name) + } + if converted.Namespace != "control-plane" { + t.Fatalf("expected namespace control-plane, got %q", converted.Namespace) + } + expectedResourceVersion := strconv.FormatInt(updatedAt.UnixNano(), 10) + if converted.ResourceVersion != expectedResourceVersion { + t.Fatalf( + "expected resource version %q from updated timestamp, got %q", + expectedResourceVersion, + converted.ResourceVersion, + ) + } + if converted.Spec.Organization != "acme" { + t.Fatalf("expected spec organization acme, got %q", converted.Spec.Organization) + } + if converted.Spec.VersionID != activeVersionID.String() { + t.Fatalf("expected spec version ID %q, got %q", activeVersionID.String(), converted.Spec.VersionID) + } + if converted.Spec.DisplayName != "Starter Template" { + t.Fatalf("expected spec display name Starter Template, got %q", converted.Spec.DisplayName) + } + if converted.Spec.Description != "Base development template" { + t.Fatalf("expected spec description Base development template, got %q", converted.Spec.Description) + } + if converted.Spec.Icon != "/icons/starter.png" { + t.Fatalf("expected spec icon /icons/starter.png, got %q", converted.Spec.Icon) + } + if converted.Status.ID != templateID.String() { + t.Fatalf("expected status ID %q, got %q", templateID.String(), converted.Status.ID) + } + if converted.Status.OrganizationName != "acme" { + t.Fatalf("expected status organization name acme, got %q", converted.Status.OrganizationName) + } + if converted.Status.ActiveVersionID != activeVersionID.String() { + t.Fatalf("expected status active version ID %q, got %q", activeVersionID.String(), converted.Status.ActiveVersionID) + } + if !converted.Status.Deprecated { + t.Fatal("expected status deprecated true") + } + if converted.Status.UpdatedAt == nil { + t.Fatal("expected status updatedAt to be set") + } + if !converted.Status.UpdatedAt.Time.Equal(updatedAt) { + t.Fatalf("expected status updatedAt %s, got %s", updatedAt, converted.Status.UpdatedAt.Time) + } +} + +func TestTemplateCreateRequestFromK8s(t *testing.T) { + t.Parallel() + + versionID := uuid.New() + obj := &aggregationv1alpha1.CoderTemplate{ + Spec: aggregationv1alpha1.CoderTemplateSpec{ + VersionID: versionID.String(), + DisplayName: "Starter Template", + Description: "Base development template", + Icon: "/icons/starter.png", + }, + } + + request, err := TemplateCreateRequestFromK8s(obj, "starter-template") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if request.Name != "starter-template" { + t.Fatalf("expected request name starter-template, got %q", request.Name) + } + if request.VersionID != versionID { + t.Fatalf("expected request version ID %q, got %q", versionID, request.VersionID) + } + if request.DisplayName != "Starter Template" { + t.Fatalf("expected request display name Starter Template, got %q", request.DisplayName) + } + if request.Description != "Base development template" { + t.Fatalf("expected request description Base development template, got %q", request.Description) + } + if request.Icon != "/icons/starter.png" { + t.Fatalf("expected request icon /icons/starter.png, got %q", request.Icon) + } +} + +func TestTemplateCreateRequestFromK8sRejectsInvalidVersionID(t *testing.T) { + t.Parallel() + + obj := &aggregationv1alpha1.CoderTemplate{ + Spec: aggregationv1alpha1.CoderTemplateSpec{VersionID: "not-a-uuid"}, + } + + _, err := TemplateCreateRequestFromK8s(obj, "starter-template") + if err == nil { + t.Fatal("expected error for invalid spec.versionID, got nil") + } + if !strings.Contains(err.Error(), "parse template spec.versionID") { + t.Fatalf("expected parse error, got %v", err) + } +} diff --git a/internal/aggregated/convert/workspace.go b/internal/aggregated/convert/workspace.go new file mode 100644 index 00000000..158c121e --- /dev/null +++ b/internal/aggregated/convert/workspace.go @@ -0,0 +1,108 @@ +package convert + +import ( + "fmt" + "strconv" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder-k8s/internal/aggregated/coder" + "github.com/coder/coder/v2/codersdk" + "github.com/google/uuid" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +// WorkspaceToK8s converts a codersdk.Workspace to an aggregated API CoderWorkspace. +func WorkspaceToK8s(namespace string, w codersdk.Workspace) *aggregationv1alpha1.CoderWorkspace { + if namespace == "" { + panic("assertion failed: namespace must not be empty") + } + + var autoShutdown *metav1.Time + if w.LatestBuild.Deadline.Valid && !w.LatestBuild.Deadline.Time.IsZero() { + autoShutdownTime := metav1.NewTime(w.LatestBuild.Deadline.Time) + autoShutdown = &autoShutdownTime + } + lastUsedAt := metav1.NewTime(w.LastUsedAt) + + return &aggregationv1alpha1.CoderWorkspace{ + TypeMeta: metav1.TypeMeta{ + Kind: "CoderWorkspace", + APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: coder.BuildWorkspaceName(w.OrganizationName, w.OwnerName, w.Name), + Namespace: namespace, + UID: types.UID(w.ID.String()), + ResourceVersion: strconv.FormatInt(w.UpdatedAt.UnixNano(), 10), + CreationTimestamp: metav1.NewTime(w.CreatedAt), + }, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + Organization: w.OrganizationName, + TemplateName: w.TemplateName, + TemplateVersionID: w.LatestBuild.TemplateVersionID.String(), + Running: workspaceRunning(w), + TTLMillis: w.TTLMillis, + AutostartSchedule: w.AutostartSchedule, + }, + Status: aggregationv1alpha1.CoderWorkspaceStatus{ + ID: w.ID.String(), + OwnerName: w.OwnerName, + OrganizationName: w.OrganizationName, + TemplateName: w.TemplateName, + LatestBuildID: w.LatestBuild.ID.String(), + LatestBuildStatus: string(w.LatestBuild.Status), + AutoShutdown: autoShutdown, + LastUsedAt: &lastUsedAt, + }, + } +} + +func workspaceRunning(workspace codersdk.Workspace) bool { + if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { + return false + } + + switch workspace.LatestBuild.Status { + case codersdk.WorkspaceStatusPending, codersdk.WorkspaceStatusStarting, codersdk.WorkspaceStatusRunning: + return true + default: + return false + } +} + +// WorkspaceCreateRequestFromK8s builds a codersdk.CreateWorkspaceRequest. +func WorkspaceCreateRequestFromK8s( + obj *aggregationv1alpha1.CoderWorkspace, + workspaceName string, + templateID uuid.UUID, +) (codersdk.CreateWorkspaceRequest, error) { + if obj == nil { + panic("assertion failed: workspace object must not be nil") + } + if workspaceName == "" { + panic("assertion failed: workspace name must not be empty") + } + if templateID == uuid.Nil { + panic("assertion failed: template ID must not be nil") + } + + request := codersdk.CreateWorkspaceRequest{ + Name: workspaceName, + TTLMillis: obj.Spec.TTLMillis, + AutostartSchedule: obj.Spec.AutostartSchedule, + } + + if obj.Spec.TemplateVersionID == "" { + request.TemplateID = templateID + return request, nil + } + + templateVersionID, err := uuid.Parse(obj.Spec.TemplateVersionID) + if err != nil { + return codersdk.CreateWorkspaceRequest{}, fmt.Errorf("invalid templateVersionID %q: %w", obj.Spec.TemplateVersionID, err) + } + + request.TemplateVersionID = templateVersionID + return request, nil +} diff --git a/internal/aggregated/convert/workspace_test.go b/internal/aggregated/convert/workspace_test.go new file mode 100644 index 00000000..caf8e133 --- /dev/null +++ b/internal/aggregated/convert/workspace_test.go @@ -0,0 +1,257 @@ +package convert + +import ( + "testing" + "time" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder/v2/codersdk" + "github.com/google/uuid" +) + +func TestWorkspaceToK8s(t *testing.T) { + t.Parallel() + + workspaceID := uuid.New() + buildID := uuid.New() + createdAt := time.Date(2025, time.February, 2, 3, 4, 5, 0, time.UTC) + updatedAt := createdAt.Add(4 * time.Hour) + lastUsedAt := createdAt.Add(3 * time.Hour) + autoShutdownAt := createdAt.Add(6 * time.Hour) + ttlMillis := int64(3600000) + autostartSchedule := "CRON_TZ=UTC 0 9 * * 1-5" + + workspace := codersdk.Workspace{ + ID: workspaceID, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + OwnerName: "alice", + OrganizationName: "acme", + TemplateName: "starter-template", + Name: "dev-workspace", + TTLMillis: &ttlMillis, + AutostartSchedule: &autostartSchedule, + LastUsedAt: lastUsedAt, + LatestBuild: codersdk.WorkspaceBuild{ + ID: buildID, + Transition: codersdk.WorkspaceTransitionStart, + Status: codersdk.WorkspaceStatusStarting, + Deadline: codersdk.NewNullTime(autoShutdownAt, true), + }, + } + + converted := WorkspaceToK8s("control-plane", workspace) + if converted == nil { + t.Fatal("expected non-nil converted workspace") + } + if converted.Name != "acme.alice.dev-workspace" { + t.Fatalf("expected name acme.alice.dev-workspace, got %q", converted.Name) + } + if converted.Namespace != "control-plane" { + t.Fatalf("expected namespace control-plane, got %q", converted.Namespace) + } + if converted.Spec.Organization != "acme" { + t.Fatalf("expected spec organization acme, got %q", converted.Spec.Organization) + } + if converted.Spec.TemplateName != "starter-template" { + t.Fatalf("expected spec template name starter-template, got %q", converted.Spec.TemplateName) + } + if !converted.Spec.Running { + t.Fatal("expected running=true when latest build transition is start") + } + if converted.Spec.TTLMillis == nil || *converted.Spec.TTLMillis != ttlMillis { + t.Fatalf("expected TTL millis %d, got %+v", ttlMillis, converted.Spec.TTLMillis) + } + if converted.Spec.AutostartSchedule == nil || *converted.Spec.AutostartSchedule != autostartSchedule { + t.Fatalf("expected autostart schedule %q, got %+v", autostartSchedule, converted.Spec.AutostartSchedule) + } + if converted.Status.ID != workspaceID.String() { + t.Fatalf("expected status ID %q, got %q", workspaceID.String(), converted.Status.ID) + } + if converted.Status.OwnerName != "alice" { + t.Fatalf("expected status owner name alice, got %q", converted.Status.OwnerName) + } + if converted.Status.OrganizationName != "acme" { + t.Fatalf("expected status organization name acme, got %q", converted.Status.OrganizationName) + } + if converted.Status.TemplateName != "starter-template" { + t.Fatalf("expected status template name starter-template, got %q", converted.Status.TemplateName) + } + if converted.Status.LatestBuildID != buildID.String() { + t.Fatalf("expected status latest build ID %q, got %q", buildID.String(), converted.Status.LatestBuildID) + } + if converted.Status.LatestBuildStatus != string(codersdk.WorkspaceStatusStarting) { + t.Fatalf("expected status latest build status %q, got %q", codersdk.WorkspaceStatusStarting, converted.Status.LatestBuildStatus) + } + if converted.Status.AutoShutdown == nil { + t.Fatal("expected status autoShutdown to be set") + } + if !converted.Status.AutoShutdown.Time.Equal(autoShutdownAt) { + t.Fatalf("expected status autoShutdown %s, got %s", autoShutdownAt, converted.Status.AutoShutdown.Time) + } + if converted.Status.LastUsedAt == nil { + t.Fatal("expected status lastUsedAt to be set") + } + if !converted.Status.LastUsedAt.Time.Equal(lastUsedAt) { + t.Fatalf("expected status lastUsedAt %s, got %s", lastUsedAt, converted.Status.LastUsedAt.Time) + } +} + +func TestWorkspaceToK8sRunningStateFromTransitionAndStatus(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + transition codersdk.WorkspaceTransition + status codersdk.WorkspaceStatus + running bool + }{ + { + name: "start pending", + transition: codersdk.WorkspaceTransitionStart, + status: codersdk.WorkspaceStatusPending, + running: true, + }, + { + name: "start starting", + transition: codersdk.WorkspaceTransitionStart, + status: codersdk.WorkspaceStatusStarting, + running: true, + }, + { + name: "start running", + transition: codersdk.WorkspaceTransitionStart, + status: codersdk.WorkspaceStatusRunning, + running: true, + }, + { + name: "start failed", + transition: codersdk.WorkspaceTransitionStart, + status: codersdk.WorkspaceStatusFailed, + running: false, + }, + { + name: "start canceled", + transition: codersdk.WorkspaceTransitionStart, + status: codersdk.WorkspaceStatusCanceled, + running: false, + }, + { + name: "stop running", + transition: codersdk.WorkspaceTransitionStop, + status: codersdk.WorkspaceStatusRunning, + running: false, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + now := time.Date(2025, time.February, 2, 3, 4, 5, 0, time.UTC) + workspace := codersdk.Workspace{ + ID: uuid.New(), + CreatedAt: now, + UpdatedAt: now, + OwnerName: "alice", + OrganizationName: "acme", + TemplateName: "starter-template", + Name: "dev-workspace", + LastUsedAt: now, + LatestBuild: codersdk.WorkspaceBuild{ + ID: uuid.New(), + Transition: testCase.transition, + Status: testCase.status, + }, + } + + converted := WorkspaceToK8s("control-plane", workspace) + if converted.Spec.Running != testCase.running { + t.Fatalf( + "expected running=%t for transition=%q status=%q, got %t", + testCase.running, + testCase.transition, + testCase.status, + converted.Spec.Running, + ) + } + }) + } +} + +func TestWorkspaceCreateRequestFromK8s(t *testing.T) { + t.Parallel() + + templateID := uuid.New() + ttlMillis := int64(3600000) + autostartSchedule := "CRON_TZ=UTC 0 9 * * 1-5" + + obj := &aggregationv1alpha1.CoderWorkspace{ + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + TTLMillis: &ttlMillis, + AutostartSchedule: &autostartSchedule, + }, + } + + request, err := WorkspaceCreateRequestFromK8s(obj, "dev-workspace", templateID) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if request.Name != "dev-workspace" { + t.Fatalf("expected request name dev-workspace, got %q", request.Name) + } + if request.TemplateID != templateID { + t.Fatalf("expected request template ID %q, got %q", templateID, request.TemplateID) + } + if request.TemplateVersionID != uuid.Nil { + t.Fatalf("expected request template version ID %q, got %q", uuid.Nil, request.TemplateVersionID) + } + if request.TTLMillis == nil || *request.TTLMillis != ttlMillis { + t.Fatalf("expected request TTL millis %d, got %+v", ttlMillis, request.TTLMillis) + } + if request.AutostartSchedule == nil || *request.AutostartSchedule != autostartSchedule { + t.Fatalf("expected request autostart schedule %q, got %+v", autostartSchedule, request.AutostartSchedule) + } +} + +func TestWorkspaceCreateRequestFromK8sUsesTemplateVersionID(t *testing.T) { + t.Parallel() + + templateID := uuid.New() + templateVersionID := uuid.New() + + obj := &aggregationv1alpha1.CoderWorkspace{ + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + TemplateVersionID: templateVersionID.String(), + }, + } + + request, err := WorkspaceCreateRequestFromK8s(obj, "dev-workspace", templateID) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if request.TemplateVersionID != templateVersionID { + t.Fatalf("expected request template version ID %q, got %q", templateVersionID, request.TemplateVersionID) + } + if request.TemplateID != uuid.Nil { + t.Fatalf("expected request template ID %q, got %q", uuid.Nil, request.TemplateID) + } +} + +func TestWorkspaceCreateRequestFromK8sReturnsErrorForInvalidTemplateVersionID(t *testing.T) { + t.Parallel() + + templateID := uuid.New() + + obj := &aggregationv1alpha1.CoderWorkspace{ + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + TemplateVersionID: "not-a-uuid", + }, + } + + _, err := WorkspaceCreateRequestFromK8s(obj, "dev-workspace", templateID) + if err == nil { + t.Fatal("expected error for invalid templateVersionID") + } +} diff --git a/internal/aggregated/storage/doc.go b/internal/aggregated/storage/doc.go new file mode 100644 index 00000000..37855a7d --- /dev/null +++ b/internal/aggregated/storage/doc.go @@ -0,0 +1,12 @@ +// Package storage implements codersdk-backed REST storage for the aggregated API +// server's CoderWorkspace and CoderTemplate resources. +// +// v1 Semantics: +// - Resources are namespace-scoped; the namespace represents the CoderControlPlane namespace. +// - Template object names follow the format ".". +// - Workspace object names follow "..". +// - The dot separator works because Coder names are alphanumeric-with-hyphens (no dots), +// while Kubernetes object names allow dots (DNS-1123 subdomains). +// - A single admin session token is used for all API calls (no per-request impersonation in v1). +// - Storage resolves the backing codersdk.Client via a ClientProvider interface. +package storage diff --git a/internal/aggregated/storage/errors.go b/internal/aggregated/storage/errors.go new file mode 100644 index 00000000..5058d161 --- /dev/null +++ b/internal/aggregated/storage/errors.go @@ -0,0 +1,20 @@ +package storage + +import ( + "errors" + + apierrors "k8s.io/apimachinery/pkg/api/errors" +) + +func wrapClientError(err error) error { + if err == nil { + return nil + } + + var statusErr *apierrors.StatusError + if errors.As(err, &statusErr) { + return statusErr + } + + return apierrors.NewInternalError(err) +} diff --git a/internal/aggregated/storage/helpers.go b/internal/aggregated/storage/helpers.go deleted file mode 100644 index 1f07e7c9..00000000 --- a/internal/aggregated/storage/helpers.go +++ /dev/null @@ -1,43 +0,0 @@ -package storage - -import ( - "context" - "fmt" - "strconv" - - apierrors "k8s.io/apimachinery/pkg/api/errors" - genericapirequest "k8s.io/apiserver/pkg/endpoints/request" -) - -func resolveWriteNamespace(ctx context.Context, objectNamespace string) (string, error) { - requestNamespace := genericapirequest.NamespaceValue(ctx) - if requestNamespace == "" && objectNamespace == "" { - return "", apierrors.NewBadRequest("namespace is required") - } - if requestNamespace == "" { - return objectNamespace, nil - } - if objectNamespace == "" { - return requestNamespace, nil - } - if requestNamespace != objectNamespace { - return "", apierrors.NewBadRequest(fmt.Sprintf("request namespace %q does not match object namespace %q", requestNamespace, objectNamespace)) - } - return requestNamespace, nil -} - -func incrementResourceVersion(resourceVersion string) (string, error) { - if resourceVersion == "" { - return "1", nil - } - - version, err := strconv.ParseInt(resourceVersion, 10, 64) - if err != nil { - return "", fmt.Errorf("assertion failed: invalid resourceVersion %q: %w", resourceVersion, err) - } - if version < 0 { - return "", fmt.Errorf("assertion failed: resourceVersion must not be negative: %d", version) - } - - return strconv.FormatInt(version+1, 10), nil -} diff --git a/internal/aggregated/storage/storage_test.go b/internal/aggregated/storage/storage_test.go index a3b242f1..47b1fd29 100644 --- a/internal/aggregated/storage/storage_test.go +++ b/internal/aggregated/storage/storage_test.go @@ -1,116 +1,2058 @@ -// Package storage provides hardcoded in-memory storage implementations for aggregated API resources. package storage import ( "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "sort" + "strings" + "sync" "testing" + "time" + "github.com/google/uuid" apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" genericapirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/apiserver/pkg/registry/rest" aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder-k8s/internal/aggregated/coder" + "github.com/coder/coder/v2/codersdk" ) -func TestWorkspaceStorageList(t *testing.T) { - t.Helper() +func TestTemplateStorageCRUDWithCoderSDK(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + listObj, err := templateStorage.List(ctx, nil) + if err != nil { + t.Fatalf("expected template list to succeed: %v", err) + } + + list, ok := listObj.(*aggregationv1alpha1.CoderTemplateList) + if !ok { + t.Fatalf("expected *CoderTemplateList, got %T", listObj) + } + if len(list.Items) != 1 { + t.Fatalf("expected one template in list, got %d", len(list.Items)) + } + if list.Items[0].Name != "acme.starter-template" { + t.Fatalf("expected template name acme.starter-template, got %q", list.Items[0].Name) + } + + obj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + template, ok := obj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate, got %T", obj) + } + if template.Spec.Organization != "acme" { + t.Fatalf("expected organization acme, got %q", template.Spec.Organization) + } + + versionID := uuid.New() + createObj := &aggregationv1alpha1.CoderTemplate{ + ObjectMeta: metav1.ObjectMeta{Name: "acme.ops-template"}, + Spec: aggregationv1alpha1.CoderTemplateSpec{ + Organization: "acme", + VersionID: versionID.String(), + DisplayName: "Ops Template", + Description: "Operations tooling", + Icon: "/icons/ops.png", + }, + } + + createdObj, err := templateStorage.Create(ctx, createObj, rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("expected template create to succeed: %v", err) + } + + createdTemplate, ok := createdObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from create, got %T", createdObj) + } + if createdTemplate.Name != "acme.ops-template" { + t.Fatalf("expected created template name acme.ops-template, got %q", createdTemplate.Name) + } + if createdTemplate.Spec.DisplayName != "Ops Template" { + t.Fatalf("expected created display name Ops Template, got %q", createdTemplate.Spec.DisplayName) + } + + if !state.hasTemplate("acme", "ops-template") { + t.Fatal("expected template to be persisted in mock server state") + } + + _, deleted, err := templateStorage.Delete(ctx, "acme.ops-template", rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("expected template delete to succeed: %v", err) + } + if !deleted { + t.Fatal("expected delete to report deleted=true") + } + + _, err = templateStorage.Get(ctx, "acme.ops-template", nil) + if !apierrors.IsNotFound(err) { + t.Fatalf("expected NotFound after delete, got %v", err) + } +} + +func TestTemplateStorageListAllowsAllNamespacesRequest(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + + listObj, err := templateStorage.List(context.Background(), nil) + if err != nil { + t.Fatalf("expected all-namespaces list to succeed, got %v", err) + } + list, ok := listObj.(*aggregationv1alpha1.CoderTemplateList) + if !ok { + t.Fatalf("expected *CoderTemplateList, got %T", listObj) + } + if len(list.Items) == 0 { + t.Fatal("expected at least one template in list") + } +} + +func TestTemplateStorageListPreservesProviderStatusErrors(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + parsedURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("parse mock server URL %q: %v", server.URL, err) + } + client := codersdk.New(parsedURL) + client.SetSessionToken("test-session-token") + + templateStorage := NewTemplateStorage(&coder.StaticClientProvider{ + Client: client, + Namespace: "control-plane", + }) + + _, err = templateStorage.List(namespacedContext("other-namespace"), nil) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest from provider namespace restriction, got %v", err) + } + assertTopLevelStatusError(t, err) +} + +func TestWrapClientErrorReturnsTopLevelStatusError(t *testing.T) { + t.Parallel() + + statusErr := apierrors.NewBadRequest("provider namespace mismatch") + wrappedErr := fmt.Errorf("resolve codersdk client for namespace %q: %w", "control-plane", statusErr) + + wrappedClientErr := wrapClientError(wrappedErr) + if !apierrors.IsBadRequest(wrappedClientErr) { + t.Fatalf("expected BadRequest from wrapped status error, got %v", wrappedClientErr) + } + + assertTopLevelStatusError(t, wrappedClientErr) + + var unwrappedStatusErr *apierrors.StatusError + if !errors.As(wrappedClientErr, &unwrappedStatusErr) { + t.Fatalf("expected *apierrors.StatusError in wrapped client error chain, got %T", wrappedClientErr) + } + if unwrappedStatusErr != statusErr { + t.Fatalf("expected wrapClientError to return original status error pointer") + } +} + +func TestTemplateStorageUpdateReturnsCurrentBackendObjectForLegacyRunningField(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + if currentTemplate.ResourceVersion == "" { + t.Fatal("expected current template resourceVersion to be populated") + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + + updatedObj, created, err := templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if err != nil { + t.Fatalf("expected template update to succeed: %v", err) + } + if created { + t.Fatal("expected update created=false") + } + + updatedTemplate, ok := updatedObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from update, got %T", updatedObj) + } + if updatedTemplate.Spec.Running != currentTemplate.Spec.Running { + t.Fatalf("expected update response running=%t from current backend object, got %t", currentTemplate.Spec.Running, updatedTemplate.Spec.Running) + } + if updatedTemplate.Name != currentTemplate.Name { + t.Fatalf("expected updated name %q, got %q", currentTemplate.Name, updatedTemplate.Name) + } +} + +func TestTemplateStorageUpdateAllowsEmptyVersionIDWhenTogglingRunning(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + if currentTemplate.Spec.VersionID == "" { + t.Fatal("expected current template spec.versionID to be populated") + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + desiredTemplate.Spec.VersionID = "" + + updatedObj, created, err := templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if err != nil { + t.Fatalf("expected template update to succeed when desired spec.versionID is empty: %v", err) + } + if created { + t.Fatal("expected update created=false") + } + + updatedTemplate, ok := updatedObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from update, got %T", updatedObj) + } + if updatedTemplate.Spec.Running != currentTemplate.Spec.Running { + t.Fatalf("expected update response running=%t from current backend object, got %t", currentTemplate.Spec.Running, updatedTemplate.Spec.Running) + } + if updatedTemplate.Spec.VersionID != currentTemplate.Spec.VersionID { + t.Fatalf("expected update response spec.versionID %q from current backend object, got %q", currentTemplate.Spec.VersionID, updatedTemplate.Spec.VersionID) + } +} + +func TestTemplateStorageUpdateAllowsEmptyOptionalFieldsWhenTogglingRunning(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + if currentTemplate.Spec.DisplayName == "" || currentTemplate.Spec.Description == "" || currentTemplate.Spec.Icon == "" { + t.Fatal("expected current template optional fields to be populated") + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + desiredTemplate.Spec.DisplayName = "" + desiredTemplate.Spec.Description = "" + desiredTemplate.Spec.Icon = "" + + updatedObj, created, err := templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if err != nil { + t.Fatalf("expected template update to succeed when optional fields are empty: %v", err) + } + if created { + t.Fatal("expected update created=false") + } + + updatedTemplate, ok := updatedObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from update, got %T", updatedObj) + } + if updatedTemplate.Spec.Running != currentTemplate.Spec.Running { + t.Fatalf("expected update response running=%t from current backend object, got %t", currentTemplate.Spec.Running, updatedTemplate.Spec.Running) + } + if updatedTemplate.Spec.DisplayName != currentTemplate.Spec.DisplayName { + t.Fatalf("expected update response spec.displayName %q from current backend object, got %q", currentTemplate.Spec.DisplayName, updatedTemplate.Spec.DisplayName) + } + if updatedTemplate.Spec.Description != currentTemplate.Spec.Description { + t.Fatalf("expected update response spec.description %q from current backend object, got %q", currentTemplate.Spec.Description, updatedTemplate.Spec.Description) + } + if updatedTemplate.Spec.Icon != currentTemplate.Spec.Icon { + t.Fatalf("expected update response spec.icon %q from current backend object, got %q", currentTemplate.Spec.Icon, updatedTemplate.Spec.Icon) + } +} + +func TestTemplateStorageUpdateRejectsDifferentVersionID(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + desiredTemplate.Spec.VersionID = uuid.New().String() + if desiredTemplate.Spec.VersionID == currentTemplate.Spec.VersionID { + t.Fatal("expected test fixture to use a different spec.versionID") + } + + _, _, err = templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when changing spec.versionID, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "spec.running") { + t.Fatalf("expected immutable-field error mentioning spec.running, got %v", err) + } +} + +func TestTemplateStorageUpdateRejectsNonRunningSpecChanges(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.DisplayName = "Renamed Template" + + _, _, err = templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when changing immutable template spec fields, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "spec.running") { + t.Fatalf("expected immutable-field error mentioning spec.running, got %v", err) + } +} + +func TestTemplateStorageUpdateRejectsMissingResourceVersion(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + if currentTemplate.ResourceVersion == "" { + t.Fatal("expected current template resourceVersion to be populated") + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + desiredTemplate.ResourceVersion = "" + + _, _, err = templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for missing resourceVersion, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "metadata.resourceVersion is required for update") { + t.Fatalf("expected missing resourceVersion error message, got %v", err) + } +} + +func TestTemplateStorageUpdateRejectsStaleResourceVersion(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + desiredTemplate.ResourceVersion = currentTemplate.ResourceVersion + "-stale" + + _, _, err = templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsConflict(err) { + t.Fatalf("expected Conflict for stale resourceVersion, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "resource version mismatch") { + t.Fatalf("expected stale resourceVersion error message, got %v", err) + } +} + +func TestTemplateStorageUpdateRejectsMismatchedName(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + desiredTemplate.Name = "acme.other-template" + + _, _, err = templateStorage.Update( + ctx, + currentTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for mismatched name, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "updated object metadata.name \"acme.other-template\" must match request name \"acme.starter-template\"") { + t.Fatalf("expected mismatched name error message, got %v", err) + } +} + +func TestTemplateStorageUpdateRejectsMismatchedNamespace(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + templateStorage := NewTemplateStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := templateStorage.Get(ctx, "acme.starter-template", nil) + if err != nil { + t.Fatalf("expected template get to succeed: %v", err) + } + + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from get, got %T", currentObj) + } + + desiredTemplate := currentTemplate.DeepCopy() + desiredTemplate.Spec.Running = !currentTemplate.Spec.Running + desiredTemplate.Namespace = "other-namespace" + + _, _, err = templateStorage.Update( + ctx, + desiredTemplate.Name, + testUpdatedObjectInfo{obj: desiredTemplate}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for mismatched namespace, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "metadata.namespace \"other-namespace\" does not match request namespace \"control-plane\"") { + t.Fatalf("expected mismatched namespace error message, got %v", err) + } +} + +func TestWorkspaceStorageCRUDWithCoderSDK(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + listObj, err := workspaceStorage.List(ctx, nil) + if err != nil { + t.Fatalf("expected workspace list to succeed: %v", err) + } + + list, ok := listObj.(*aggregationv1alpha1.CoderWorkspaceList) + if !ok { + t.Fatalf("expected *CoderWorkspaceList, got %T", listObj) + } + if len(list.Items) != 1 { + t.Fatalf("expected one workspace in list, got %d", len(list.Items)) + } + if list.Items[0].Name != "acme.alice.dev-workspace" { + t.Fatalf("expected workspace name acme.alice.dev-workspace, got %q", list.Items[0].Name) + } + + obj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + workspace, ok := obj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace, got %T", obj) + } + if !workspace.Spec.Running { + t.Fatal("expected initial workspace to be running") + } + + ttlMillis := int64(7200000) + autostartSchedule := "CRON_TZ=UTC 0 10 * * 1-5" + createObj := &aggregationv1alpha1.CoderWorkspace{ + ObjectMeta: metav1.ObjectMeta{Name: "acme.alice.ops-workspace"}, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + Organization: "acme", + TemplateName: "starter-template", + Running: false, + TTLMillis: &ttlMillis, + AutostartSchedule: &autostartSchedule, + }, + } + + createdObj, err := workspaceStorage.Create(ctx, createObj, rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("expected workspace create to succeed: %v", err) + } + + createdWorkspace, ok := createdObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from create, got %T", createdObj) + } + if createdWorkspace.Spec.Running { + t.Fatal("expected created workspace to be stopped when spec.running=false") + } + if !state.hasWorkspace("alice", "ops-workspace") { + t.Fatal("expected workspace to be persisted in mock server state") + } + if !containsTransition(state.buildTransitionsSnapshot(), codersdk.WorkspaceTransitionStop) { + t.Fatal("expected create to queue stop transition when running=false") + } + + desiredWorkspace := createdWorkspace.DeepCopy() + desiredWorkspace.Spec.Running = true + + updatedObj, created, err := workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if err != nil { + t.Fatalf("expected workspace update to succeed: %v", err) + } + if created { + t.Fatal("expected update created=false") + } + + updatedWorkspace, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from update, got %T", updatedObj) + } + if !updatedWorkspace.Spec.Running { + t.Fatal("expected updated workspace to be running") + } + if !containsTransition(state.buildTransitionsSnapshot(), codersdk.WorkspaceTransitionStart) { + t.Fatal("expected update to queue start transition") + } + + _, deleted, err := workspaceStorage.Delete(ctx, desiredWorkspace.Name, rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("expected workspace delete to succeed: %v", err) + } + if deleted { + t.Fatal("expected delete to report deleted=false for async delete transition") + } + if !containsTransition(state.buildTransitionsSnapshot(), codersdk.WorkspaceTransitionDelete) { + t.Fatal("expected delete to queue delete transition") + } +} + +func TestWorkspaceStorageCreateRejectsTemplateVersionIDFromDifferentTemplate(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + mismatchedTemplateVersionID := uuid.New() + state.setTemplateVersionTemplateID(mismatchedTemplateVersionID, uuid.New()) + + createObj := &aggregationv1alpha1.CoderWorkspace{ + ObjectMeta: metav1.ObjectMeta{Name: "acme.alice.mismatch-template-version-workspace"}, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + Organization: "acme", + TemplateName: "starter-template", + TemplateVersionID: mismatchedTemplateVersionID.String(), + Running: true, + }, + } + + _, err := workspaceStorage.Create(ctx, createObj, rest.ValidateAllObjectFunc, nil) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when templateVersionID belongs to a different template, got %v", err) + } + + expectedMessage := fmt.Sprintf( + "spec.templateVersionID %q does not belong to template %q", + mismatchedTemplateVersionID.String(), + "starter-template", + ) + if err == nil || !strings.Contains(err.Error(), expectedMessage) { + t.Fatalf("expected mismatched templateVersionID error message %q, got %v", expectedMessage, err) + } + if state.hasWorkspace("alice", "mismatch-template-version-workspace") { + t.Fatal("expected workspace create to be rejected before persistence") + } + if transitions := state.buildTransitionsSnapshot(); len(transitions) != 0 { + t.Fatalf("expected no workspace build transitions on mismatched templateVersionID, got %v", transitions) + } +} + +func TestWorkspaceStorageCreateAllowsMatchingTemplateVersionID(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + templateVersionID, ok := state.workspaceLatestBuildTemplateVersionID("alice", "dev-workspace") + if !ok { + t.Fatal("expected workspace template version ID in mock server state") + } + if templateVersionID == uuid.Nil { + t.Fatal("expected workspace template version ID to be non-nil") + } + + createObj := &aggregationv1alpha1.CoderWorkspace{ + ObjectMeta: metav1.ObjectMeta{Name: "acme.alice.matching-template-version-workspace"}, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + Organization: "acme", + TemplateName: "starter-template", + TemplateVersionID: templateVersionID.String(), + Running: true, + }, + } + + createdObj, err := workspaceStorage.Create(ctx, createObj, rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("expected workspace create to succeed for matching templateVersionID: %v", err) + } + + createdWorkspace, ok := createdObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from create, got %T", createdObj) + } + if createdWorkspace.Spec.TemplateVersionID != templateVersionID.String() { + t.Fatalf( + "expected created spec.templateVersionID %q, got %q", + templateVersionID.String(), + createdWorkspace.Spec.TemplateVersionID, + ) + } + if !state.hasWorkspace("alice", "matching-template-version-workspace") { + t.Fatal("expected workspace to be persisted in mock server state") + } + if transitions := state.buildTransitionsSnapshot(); len(transitions) != 0 { + t.Fatalf("expected no workspace build transitions when spec.running=true, got %v", transitions) + } +} + +func TestWorkspaceStorageUpdateRejectsNonRunningSpecChanges(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.TemplateName = "renamed-template" + + _, _, err = workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when changing immutable workspace spec fields, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "spec.running") { + t.Fatalf("expected immutable-field error mentioning spec.running, got %v", err) + } +} + +func TestWorkspaceStorageUpdateAllowsPinnedTemplateVersionIDWhenTogglingRunning(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + + templateVersionID, ok := state.workspaceLatestBuildTemplateVersionID("alice", "dev-workspace") + if !ok { + t.Fatal("expected workspace template version ID in mock server state") + } + if templateVersionID == uuid.Nil { + t.Fatal("expected workspace template version ID to be non-nil") + } + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.TemplateVersionID = templateVersionID.String() + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + + updatedObj, created, err := workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if err != nil { + t.Fatalf("expected workspace update to succeed when templateVersionID is unchanged: %v", err) + } + if created { + t.Fatal("expected update created=false") + } + + updatedWorkspace, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from update, got %T", updatedObj) + } + if updatedWorkspace.Spec.Running != desiredWorkspace.Spec.Running { + t.Fatalf("expected updated running=%t, got %t", desiredWorkspace.Spec.Running, updatedWorkspace.Spec.Running) + } + if updatedWorkspace.Spec.TemplateVersionID != templateVersionID.String() { + t.Fatalf( + "expected updated templateVersionID %q, got %q", + templateVersionID.String(), + updatedWorkspace.Spec.TemplateVersionID, + ) + } + + expectedTransition := codersdk.WorkspaceTransitionStop + if desiredWorkspace.Spec.Running { + expectedTransition = codersdk.WorkspaceTransitionStart + } + if !containsTransition(state.buildTransitionsSnapshot(), expectedTransition) { + t.Fatalf("expected update to queue %q transition", expectedTransition) + } +} + +func TestWorkspaceStorageUpdateAllowsEmptyTemplateVersionIDWhenTogglingRunning(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + if currentWorkspace.Spec.TemplateVersionID == "" { + t.Fatal("expected current workspace spec.templateVersionID to be populated") + } + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.TemplateVersionID = "" + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + + updatedObj, created, err := workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if err != nil { + t.Fatalf("expected workspace update to succeed when desired spec.templateVersionID is empty: %v", err) + } + if created { + t.Fatal("expected update created=false") + } + + updatedWorkspace, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from update, got %T", updatedObj) + } + if updatedWorkspace.Spec.Running != desiredWorkspace.Spec.Running { + t.Fatalf("expected updated running=%t, got %t", desiredWorkspace.Spec.Running, updatedWorkspace.Spec.Running) + } + if updatedWorkspace.Spec.TemplateVersionID != currentWorkspace.Spec.TemplateVersionID { + t.Fatalf( + "expected updated templateVersionID %q, got %q", + currentWorkspace.Spec.TemplateVersionID, + updatedWorkspace.Spec.TemplateVersionID, + ) + } + + expectedTransition := codersdk.WorkspaceTransitionStop + if desiredWorkspace.Spec.Running { + expectedTransition = codersdk.WorkspaceTransitionStart + } + if !containsTransition(state.buildTransitionsSnapshot(), expectedTransition) { + t.Fatalf("expected update to queue %q transition", expectedTransition) + } +} + +func TestWorkspaceStorageUpdateAllowsNilOptionalFieldsWhenTogglingRunning(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + if currentWorkspace.Spec.TTLMillis == nil || currentWorkspace.Spec.AutostartSchedule == nil { + t.Fatal("expected current workspace optional fields to be populated") + } + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + desiredWorkspace.Spec.TTLMillis = nil + desiredWorkspace.Spec.AutostartSchedule = nil + + updatedObj, created, err := workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if err != nil { + t.Fatalf("expected workspace update to succeed when optional fields are nil: %v", err) + } + if created { + t.Fatal("expected update created=false") + } + + updatedWorkspace, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from update, got %T", updatedObj) + } + if updatedWorkspace.Spec.Running != desiredWorkspace.Spec.Running { + t.Fatalf("expected updated running=%t, got %t", desiredWorkspace.Spec.Running, updatedWorkspace.Spec.Running) + } + if updatedWorkspace.Spec.TTLMillis == nil || *updatedWorkspace.Spec.TTLMillis != *currentWorkspace.Spec.TTLMillis { + t.Fatalf( + "expected returned spec.ttlMillis to remain %v, got %v", + *currentWorkspace.Spec.TTLMillis, + updatedWorkspace.Spec.TTLMillis, + ) + } + if updatedWorkspace.Spec.AutostartSchedule == nil || *updatedWorkspace.Spec.AutostartSchedule != *currentWorkspace.Spec.AutostartSchedule { + t.Fatalf( + "expected returned spec.autostartSchedule to remain %q, got %v", + *currentWorkspace.Spec.AutostartSchedule, + updatedWorkspace.Spec.AutostartSchedule, + ) + } + + expectedTransition := codersdk.WorkspaceTransitionStop + if desiredWorkspace.Spec.Running { + expectedTransition = codersdk.WorkspaceTransitionStart + } + if !containsTransition(state.buildTransitionsSnapshot(), expectedTransition) { + t.Fatalf("expected update to queue %q transition", expectedTransition) + } +} + +func TestWorkspaceStorageUpdateRejectsDifferentTTLMillis(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + if currentWorkspace.Spec.TTLMillis == nil { + t.Fatal("expected current workspace spec.ttlMillis to be populated") + } + + differentTTLMillis := *currentWorkspace.Spec.TTLMillis + 60000 + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + desiredWorkspace.Spec.TTLMillis = &differentTTLMillis + + _, _, err = workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when changing spec.ttlMillis, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "spec.running") { + t.Fatalf("expected immutable-field error mentioning spec.running, got %v", err) + } + if transitions := state.buildTransitionsSnapshot(); len(transitions) != 0 { + t.Fatalf("expected no workspace build transitions on immutable-field error, got %v", transitions) + } +} + +func TestWorkspaceStorageUpdateRejectsDifferentTemplateVersionID(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + desiredWorkspace.Spec.TemplateVersionID = uuid.New().String() + if desiredWorkspace.Spec.TemplateVersionID == currentWorkspace.Spec.TemplateVersionID { + t.Fatal("expected test fixture to use a different spec.templateVersionID") + } + + _, _, err = workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when changing spec.templateVersionID, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "spec.running") { + t.Fatalf("expected immutable-field error mentioning spec.running, got %v", err) + } + if transitions := state.buildTransitionsSnapshot(); len(transitions) != 0 { + t.Fatalf("expected no workspace build transitions on immutable-field error, got %v", transitions) + } +} + +func TestWorkspaceStorageUpdateRejectsMissingResourceVersion(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) + if err != nil { + t.Fatalf("expected workspace get to succeed: %v", err) + } + + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + if currentWorkspace.ResourceVersion == "" { + t.Fatal("expected current workspace resourceVersion to be populated") + } + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + desiredWorkspace.ResourceVersion = "" + + _, _, err = workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for missing resourceVersion, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "metadata.resourceVersion is required for update") { + t.Fatalf("expected missing resourceVersion error message, got %v", err) + } + if transitions := state.buildTransitionsSnapshot(); len(transitions) != 0 { + t.Fatalf("expected no workspace build transitions when resourceVersion is missing, got %v", transitions) + } +} - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") +func TestWorkspaceStorageUpdateRejectsStaleResourceVersion(t *testing.T) { + t.Parallel() - obj, err := workspaceStorage.List(ctx, nil) + server, state := newMockCoderServer(t) + defer server.Close() + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) if err != nil { - t.Fatalf("expected workspace list to succeed: %v", err) + t.Fatalf("expected workspace get to succeed: %v", err) } - list, ok := obj.(*aggregationv1alpha1.CoderWorkspaceList) + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) if !ok { - t.Fatalf("expected *CoderWorkspaceList, got %T", obj) + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) } - if len(list.Items) == 0 { - t.Fatal("expected non-empty workspace list") + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + desiredWorkspace.ResourceVersion = currentWorkspace.ResourceVersion + "-stale" + + _, _, err = workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsConflict(err) { + t.Fatalf("expected Conflict for stale resourceVersion, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "resource version mismatch") { + t.Fatalf("expected stale resourceVersion error message, got %v", err) + } + if transitions := state.buildTransitionsSnapshot(); len(transitions) != 0 { + t.Fatalf("expected no workspace build transitions on stale resourceVersion conflict, got %v", transitions) } } -func TestWorkspaceStorageGet(t *testing.T) { - t.Helper() +func TestWorkspaceStorageUpdateRejectsMismatchedNamespace(t *testing.T) { + t.Parallel() - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") + server, state := newMockCoderServer(t) + defer server.Close() - obj, err := workspaceStorage.Get(ctx, "dev-workspace", nil) + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + currentObj, err := workspaceStorage.Get(ctx, "acme.alice.dev-workspace", nil) if err != nil { t.Fatalf("expected workspace get to succeed: %v", err) } - workspace, ok := obj.(*aggregationv1alpha1.CoderWorkspace) + currentWorkspace, ok := currentObj.(*aggregationv1alpha1.CoderWorkspace) if !ok { - t.Fatalf("expected *CoderWorkspace, got %T", obj) + t.Fatalf("expected *CoderWorkspace from get, got %T", currentObj) + } + + desiredWorkspace := currentWorkspace.DeepCopy() + desiredWorkspace.Spec.Running = !currentWorkspace.Spec.Running + desiredWorkspace.Namespace = "other-namespace" + + _, _, err = workspaceStorage.Update( + ctx, + desiredWorkspace.Name, + testUpdatedObjectInfo{obj: desiredWorkspace}, + nil, + rest.ValidateAllObjectUpdateFunc, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for mismatched namespace, got %v", err) } - if workspace.Name != "dev-workspace" { - t.Fatalf("expected dev-workspace, got %q", workspace.Name) + if err == nil || !strings.Contains(err.Error(), "metadata.namespace \"other-namespace\" does not match request namespace \"control-plane\"") { + t.Fatalf("expected mismatched namespace error message, got %v", err) + } + if transitions := state.buildTransitionsSnapshot(); len(transitions) != 0 { + t.Fatalf("expected no workspace build transitions on namespace validation error, got %v", transitions) } } -func TestWorkspaceStorageGetNotFound(t *testing.T) { - t.Helper() +func TestWorkspaceStorageCreateRunningFalseReturnsWorkspaceWhenStopBuildFails(t *testing.T) { + t.Parallel() + + server, state := newMockCoderServer(t) + defer server.Close() + + state.setBuildTransitionFailure(codersdk.WorkspaceTransitionStop, http.StatusBadRequest) + + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + createObj := &aggregationv1alpha1.CoderWorkspace{ + ObjectMeta: metav1.ObjectMeta{Name: "acme.alice.ops-workspace"}, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{ + Organization: "acme", + TemplateName: "starter-template", + Running: false, + }, + } + + createdObj, err := workspaceStorage.Create(ctx, createObj, rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("expected workspace create to succeed even when stop build fails: %v", err) + } + + createdWorkspace, ok := createdObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from create, got %T", createdObj) + } + if !createdWorkspace.Spec.Running { + t.Fatal("expected created workspace to remain running when stop build fails") + } + if !state.hasWorkspace("alice", "ops-workspace") { + t.Fatal("expected workspace to be persisted in mock server state") + } + if containsTransition(state.buildTransitionsSnapshot(), codersdk.WorkspaceTransitionStop) { + t.Fatal("expected failed stop transition to be absent from transition history") + } +} + +func TestWorkspaceStorageGetOrgMismatchReturnsNotFound(t *testing.T) { + t.Parallel() - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") + server, _ := newMockCoderServer(t) + defer server.Close() - _, err := workspaceStorage.Get(ctx, "does-not-exist", nil) + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + ctx := namespacedContext("control-plane") + + _, err := workspaceStorage.Get(ctx, "otherorg.alice.dev-workspace", nil) if !apierrors.IsNotFound(err) { - t.Fatalf("expected NotFound error, got %v", err) + t.Fatalf("expected NotFound when organization segment mismatches workspace org, got %v", err) } } -func TestTemplateStorageList(t *testing.T) { - t.Helper() +func TestWorkspaceStorageListAllowsAllNamespacesRequest(t *testing.T) { + t.Parallel() - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") + server, _ := newMockCoderServer(t) + defer server.Close() - obj, err := templateStorage.List(ctx, nil) + workspaceStorage := NewWorkspaceStorage(newTestClientProvider(t, server.URL)) + + listObj, err := workspaceStorage.List(context.Background(), nil) if err != nil { - t.Fatalf("expected template list to succeed: %v", err) + t.Fatalf("expected all-namespaces list to succeed, got %v", err) } - - list, ok := obj.(*aggregationv1alpha1.CoderTemplateList) + list, ok := listObj.(*aggregationv1alpha1.CoderWorkspaceList) if !ok { - t.Fatalf("expected *CoderTemplateList, got %T", obj) + t.Fatalf("expected *CoderWorkspaceList, got %T", listObj) } if len(list.Items) == 0 { - t.Fatal("expected non-empty template list") + t.Fatal("expected at least one workspace in list") + } +} + +func TestWorkspaceStorageListPreservesProviderStatusErrors(t *testing.T) { + t.Parallel() + + server, _ := newMockCoderServer(t) + defer server.Close() + + parsedURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("parse mock server URL %q: %v", server.URL, err) + } + client := codersdk.New(parsedURL) + client.SetSessionToken("test-session-token") + + workspaceStorage := NewWorkspaceStorage(&coder.StaticClientProvider{ + Client: client, + Namespace: "control-plane", + }) + + _, err = workspaceStorage.List(namespacedContext("other-namespace"), nil) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest from provider namespace restriction, got %v", err) + } + assertTopLevelStatusError(t, err) +} + +func assertTopLevelStatusError(t *testing.T, err error) { + t.Helper() + + if err == nil { + t.Fatal("expected error to be non-nil") + } + + if reflect.TypeOf(err) != reflect.TypeOf(&apierrors.StatusError{}) { + t.Fatalf("expected top-level error type *apierrors.StatusError, got %T", err) + } +} + +type testUpdatedObjectInfo struct { + obj runtime.Object + err error +} + +func (i testUpdatedObjectInfo) Preconditions() *metav1.Preconditions { + return nil +} + +func (i testUpdatedObjectInfo) UpdatedObject(context.Context, runtime.Object) (runtime.Object, error) { + if i.err != nil { + return nil, i.err + } + if i.obj == nil { + return nil, fmt.Errorf("assertion failed: updated object must not be nil") + } + + return i.obj, nil +} + +type mockCoderServerState struct { + mu sync.Mutex + + organization codersdk.Organization + + templatesByID map[uuid.UUID]codersdk.Template + templateIDsByOrg map[string]map[string]uuid.UUID + templateVersionsByID map[uuid.UUID]codersdk.TemplateVersion + workspacesByID map[uuid.UUID]codersdk.Workspace + workspaceIDsByUser map[string]map[string]uuid.UUID + + buildTransitions []codersdk.WorkspaceTransition + failBuildTransitions map[codersdk.WorkspaceTransition]int +} + +func newMockCoderServer(t *testing.T) (*httptest.Server, *mockCoderServerState) { + t.Helper() + + now := time.Date(2026, time.January, 1, 12, 0, 0, 0, time.UTC) + orgID := uuid.New() + templateID := uuid.New() + activeVersionID := uuid.New() + workspaceID := uuid.New() + workspaceBuildID := uuid.New() + ttlMillis := int64(3600000) + autostartSchedule := "CRON_TZ=UTC 0 9 * * 1-5" + + organization := codersdk.Organization{ + MinimalOrganization: codersdk.MinimalOrganization{ + ID: orgID, + Name: "acme", + DisplayName: "Acme", + }, + CreatedAt: now.Add(-24 * time.Hour), + UpdatedAt: now.Add(-1 * time.Hour), + } + + template := codersdk.Template{ + ID: templateID, + CreatedAt: now.Add(-12 * time.Hour), + UpdatedAt: now.Add(-2 * time.Hour), + OrganizationID: orgID, + OrganizationName: "acme", + Name: "starter-template", + DisplayName: "Starter Template", + Description: "Default development template", + Icon: "/icons/starter.png", + ActiveVersionID: activeVersionID, + } + + templateIDForVersion := template.ID + templateVersion := codersdk.TemplateVersion{ + ID: activeVersionID, + TemplateID: &templateIDForVersion, + OrganizationID: orgID, + CreatedAt: now.Add(-11 * time.Hour), + UpdatedAt: now.Add(-2 * time.Hour), + Name: "starter-template-v1", + Message: "initial version", + } + + workspace := codersdk.Workspace{ + ID: workspaceID, + CreatedAt: now.Add(-8 * time.Hour), + UpdatedAt: now.Add(-30 * time.Minute), + OwnerName: "alice", + OrganizationID: orgID, + OrganizationName: "acme", + TemplateID: templateID, + TemplateName: "starter-template", + Name: "dev-workspace", + TTLMillis: &ttlMillis, + AutostartSchedule: &autostartSchedule, + LastUsedAt: now.Add(-10 * time.Minute), + LatestBuild: codersdk.WorkspaceBuild{ + ID: workspaceBuildID, + WorkspaceID: workspaceID, + WorkspaceName: "dev-workspace", + WorkspaceOwnerName: "alice", + TemplateVersionID: activeVersionID, + Transition: codersdk.WorkspaceTransitionStart, + Status: codersdk.WorkspaceStatusRunning, + CreatedAt: now.Add(-30 * time.Minute), + UpdatedAt: now.Add(-30 * time.Minute), + }, + } + + state := &mockCoderServerState{ + organization: organization, + templatesByID: map[uuid.UUID]codersdk.Template{ + template.ID: template, + }, + templateIDsByOrg: map[string]map[string]uuid.UUID{ + "acme": { + template.Name: template.ID, + }, + }, + templateVersionsByID: map[uuid.UUID]codersdk.TemplateVersion{ + templateVersion.ID: templateVersion, + }, + workspacesByID: map[uuid.UUID]codersdk.Workspace{ + workspace.ID: workspace, + }, + workspaceIDsByUser: map[string]map[string]uuid.UUID{ + "alice": { + workspace.Name: workspace.ID, + }, + }, + buildTransitions: []codersdk.WorkspaceTransition{}, + failBuildTransitions: map[codersdk.WorkspaceTransition]int{}, } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + state.handleRequest(t, w, r) + })) + + return server, state } -func TestTemplateStorageGet(t *testing.T) { +func (s *mockCoderServerState) handleRequest(t *testing.T, w http.ResponseWriter, r *http.Request) { t.Helper() - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") + segments := splitPath(r.URL.Path) + + switch { + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "organizations") && len(segments) == 4: + s.handleGetOrganization(w, segments[3]) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "templates") && len(segments) == 3: + s.handleListTemplates(w) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "organizations") && len(segments) == 6 && segments[4] == "templates": + s.handleGetTemplateByName(w, segments[3], segments[5]) + return + case r.Method == http.MethodPost && hasSegments(segments, "api", "v2", "organizations") && len(segments) == 5 && segments[4] == "templates": + s.handleCreateTemplate(w, r, segments[3]) + return + case r.Method == http.MethodDelete && hasSegments(segments, "api", "v2", "templates") && len(segments) == 4: + s.handleDeleteTemplate(w, segments[3]) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "templateversions") && len(segments) == 4: + s.handleGetTemplateVersion(w, segments[3]) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "workspaces") && len(segments) == 3: + s.handleListWorkspaces(w) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "users") && len(segments) == 6 && segments[4] == "workspace": + s.handleGetWorkspace(w, segments[3], segments[5]) + return + case r.Method == http.MethodPost && hasSegments(segments, "api", "v2", "users") && len(segments) == 5 && segments[4] == "workspaces": + s.handleCreateWorkspace(w, r, segments[3]) + return + case r.Method == http.MethodPost && hasSegments(segments, "api", "v2", "workspaces") && len(segments) == 5 && segments[4] == "builds": + s.handleCreateWorkspaceBuild(w, r, segments[3]) + return + default: + writeCoderError(w, http.StatusNotFound, fmt.Sprintf("unexpected route: %s %s", r.Method, r.URL.Path)) + return + } +} + +func (s *mockCoderServerState) handleGetOrganization(w http.ResponseWriter, orgSegment string) { + s.mu.Lock() + defer s.mu.Unlock() + + if orgSegment != s.organization.Name && orgSegment != s.organization.ID.String() { + writeCoderError(w, http.StatusNotFound, "organization not found") + return + } + + writeJSON(w, http.StatusOK, s.organization) +} + +func (s *mockCoderServerState) handleListTemplates(w http.ResponseWriter) { + s.mu.Lock() + defer s.mu.Unlock() + + templates := make([]codersdk.Template, 0, len(s.templatesByID)) + for _, template := range s.templatesByID { + templates = append(templates, template) + } + sort.Slice(templates, func(i, j int) bool { + if templates[i].OrganizationName == templates[j].OrganizationName { + return templates[i].Name < templates[j].Name + } + return templates[i].OrganizationName < templates[j].OrganizationName + }) + + writeJSON(w, http.StatusOK, templates) +} + +func (s *mockCoderServerState) handleGetTemplateByName(w http.ResponseWriter, orgSegment, templateName string) { + s.mu.Lock() + defer s.mu.Unlock() + + if orgSegment != s.organization.Name && orgSegment != s.organization.ID.String() { + writeCoderError(w, http.StatusNotFound, "organization not found") + return + } + + orgTemplates, ok := s.templateIDsByOrg[s.organization.Name] + if !ok { + writeCoderError(w, http.StatusNotFound, "template not found") + return + } + templateID, ok := orgTemplates[templateName] + if !ok { + writeCoderError(w, http.StatusNotFound, "template not found") + return + } + template := s.templatesByID[templateID] + + writeJSON(w, http.StatusOK, template) +} + +func (s *mockCoderServerState) handleCreateTemplate(w http.ResponseWriter, r *http.Request, orgSegment string) { + s.mu.Lock() + defer s.mu.Unlock() + + if orgSegment != s.organization.Name && orgSegment != s.organization.ID.String() { + writeCoderError(w, http.StatusNotFound, "organization not found") + return + } + + var request codersdk.CreateTemplateRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + writeCoderError(w, http.StatusBadRequest, fmt.Sprintf("decode create template request: %v", err)) + return + } + + now := time.Now().UTC() + template := codersdk.Template{ + ID: uuid.New(), + CreatedAt: now, + UpdatedAt: now, + OrganizationID: s.organization.ID, + OrganizationName: s.organization.Name, + Name: request.Name, + DisplayName: request.DisplayName, + Description: request.Description, + Icon: request.Icon, + ActiveVersionID: request.VersionID, + } + + s.templatesByID[template.ID] = template + orgTemplates, ok := s.templateIDsByOrg[s.organization.Name] + if !ok { + orgTemplates = map[string]uuid.UUID{} + s.templateIDsByOrg[s.organization.Name] = orgTemplates + } + orgTemplates[template.Name] = template.ID + + writeJSON(w, http.StatusCreated, template) +} + +func (s *mockCoderServerState) handleDeleteTemplate(w http.ResponseWriter, templateIDSegment string) { + s.mu.Lock() + defer s.mu.Unlock() - obj, err := templateStorage.Get(ctx, "starter-template", nil) + templateID, err := uuid.Parse(templateIDSegment) if err != nil { - t.Fatalf("expected template get to succeed: %v", err) + writeCoderError(w, http.StatusBadRequest, fmt.Sprintf("invalid template id %q", templateIDSegment)) + return } - template, ok := obj.(*aggregationv1alpha1.CoderTemplate) + template, ok := s.templatesByID[templateID] if !ok { - t.Fatalf("expected *CoderTemplate, got %T", obj) + writeCoderError(w, http.StatusNotFound, "template not found") + return + } + + delete(s.templatesByID, templateID) + orgTemplates := s.templateIDsByOrg[template.OrganizationName] + delete(orgTemplates, template.Name) + + writeJSON(w, http.StatusOK, map[string]string{"message": "template deleted"}) +} + +func (s *mockCoderServerState) handleGetTemplateVersion(w http.ResponseWriter, templateVersionIDSegment string) { + s.mu.Lock() + defer s.mu.Unlock() + + templateVersionID, err := uuid.Parse(templateVersionIDSegment) + if err != nil { + writeCoderError(w, http.StatusBadRequest, fmt.Sprintf("invalid template version id %q", templateVersionIDSegment)) + return + } + + templateVersion, ok := s.templateVersionsByID[templateVersionID] + if !ok { + writeCoderError(w, http.StatusNotFound, "template version not found") + return + } + + writeJSON(w, http.StatusOK, templateVersion) +} + +func (s *mockCoderServerState) handleListWorkspaces(w http.ResponseWriter) { + s.mu.Lock() + defer s.mu.Unlock() + + workspaces := make([]codersdk.Workspace, 0, len(s.workspacesByID)) + for _, workspace := range s.workspacesByID { + workspaces = append(workspaces, workspace) + } + sort.Slice(workspaces, func(i, j int) bool { + if workspaces[i].OrganizationName == workspaces[j].OrganizationName { + if workspaces[i].OwnerName == workspaces[j].OwnerName { + return workspaces[i].Name < workspaces[j].Name + } + return workspaces[i].OwnerName < workspaces[j].OwnerName + } + return workspaces[i].OrganizationName < workspaces[j].OrganizationName + }) + + writeJSON(w, http.StatusOK, codersdk.WorkspacesResponse{Workspaces: workspaces, Count: len(workspaces)}) +} + +func (s *mockCoderServerState) handleGetWorkspace(w http.ResponseWriter, owner, workspaceName string) { + s.mu.Lock() + defer s.mu.Unlock() + + userWorkspaces, ok := s.workspaceIDsByUser[owner] + if !ok { + writeCoderError(w, http.StatusNotFound, "workspace not found") + return + } + workspaceID, ok := userWorkspaces[workspaceName] + if !ok { + writeCoderError(w, http.StatusNotFound, "workspace not found") + return + } + workspace := s.workspacesByID[workspaceID] + + writeJSON(w, http.StatusOK, workspace) +} + +func (s *mockCoderServerState) handleCreateWorkspace(w http.ResponseWriter, r *http.Request, user string) { + s.mu.Lock() + defer s.mu.Unlock() + + var request codersdk.CreateWorkspaceRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + writeCoderError(w, http.StatusBadRequest, fmt.Sprintf("decode create workspace request: %v", err)) + return + } + + templateID := request.TemplateID + templateVersionID := request.TemplateVersionID + if templateID == uuid.Nil && templateVersionID == uuid.Nil { + writeCoderError(w, http.StatusBadRequest, "template_id or template_version_id is required") + return + } + + if templateVersionID != uuid.Nil { + templateVersion, ok := s.templateVersionsByID[templateVersionID] + if !ok { + writeCoderError(w, http.StatusNotFound, "template version not found") + return + } + if templateVersion.TemplateID == nil || *templateVersion.TemplateID == uuid.Nil { + writeCoderError( + w, + http.StatusBadRequest, + fmt.Sprintf("template version %q is not associated with a template", templateVersionID.String()), + ) + return + } + if templateID != uuid.Nil && *templateVersion.TemplateID != templateID { + writeCoderError( + w, + http.StatusBadRequest, + fmt.Sprintf( + "template version %q does not belong to template %q", + templateVersionID.String(), + templateID.String(), + ), + ) + return + } + + templateID = *templateVersion.TemplateID + } + + template, ok := s.templatesByID[templateID] + if !ok { + writeCoderError(w, http.StatusNotFound, "template not found") + return + } + if templateVersionID == uuid.Nil { + templateVersionID = template.ActiveVersionID + } + + now := time.Now().UTC() + workspaceID := uuid.New() + build := codersdk.WorkspaceBuild{ + ID: uuid.New(), + CreatedAt: now, + UpdatedAt: now, + WorkspaceID: workspaceID, + WorkspaceName: request.Name, + WorkspaceOwnerName: user, + TemplateVersionID: templateVersionID, + Transition: codersdk.WorkspaceTransitionStart, + Status: codersdk.WorkspaceStatusRunning, + } + workspace := codersdk.Workspace{ + ID: workspaceID, + CreatedAt: now, + UpdatedAt: now, + OwnerName: user, + OrganizationID: template.OrganizationID, + OrganizationName: template.OrganizationName, + TemplateID: template.ID, + TemplateName: template.Name, + Name: request.Name, + TTLMillis: request.TTLMillis, + AutostartSchedule: request.AutostartSchedule, + LastUsedAt: now, + LatestBuild: build, + } + + s.workspacesByID[workspace.ID] = workspace + userWorkspaces, ok := s.workspaceIDsByUser[user] + if !ok { + userWorkspaces = map[string]uuid.UUID{} + s.workspaceIDsByUser[user] = userWorkspaces + } + userWorkspaces[workspace.Name] = workspace.ID + + writeJSON(w, http.StatusCreated, workspace) +} + +func (s *mockCoderServerState) handleCreateWorkspaceBuild(w http.ResponseWriter, r *http.Request, workspaceIDSegment string) { + s.mu.Lock() + defer s.mu.Unlock() + + workspaceID, err := uuid.Parse(workspaceIDSegment) + if err != nil { + writeCoderError(w, http.StatusBadRequest, fmt.Sprintf("invalid workspace id %q", workspaceIDSegment)) + return + } + + workspace, ok := s.workspacesByID[workspaceID] + if !ok { + writeCoderError(w, http.StatusNotFound, "workspace not found") + return + } + + var request codersdk.CreateWorkspaceBuildRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + writeCoderError(w, http.StatusBadRequest, fmt.Sprintf("decode create workspace build request: %v", err)) + return + } + + if statusCode, shouldFail := s.failBuildTransitions[request.Transition]; shouldFail { + writeCoderError(w, statusCode, fmt.Sprintf("forced failure for transition %q", request.Transition)) + return + } + + now := time.Now().UTC() + build := codersdk.WorkspaceBuild{ + ID: uuid.New(), + CreatedAt: now, + UpdatedAt: now, + WorkspaceID: workspace.ID, + WorkspaceName: workspace.Name, + WorkspaceOwnerName: workspace.OwnerName, + TemplateVersionID: workspace.LatestBuild.TemplateVersionID, + Transition: request.Transition, + Status: statusFromTransition(request.Transition), + } + + workspace.LatestBuild = build + workspace.UpdatedAt = now + s.workspacesByID[workspace.ID] = workspace + s.buildTransitions = append(s.buildTransitions, request.Transition) + + writeJSON(w, http.StatusCreated, build) +} + +func (s *mockCoderServerState) hasTemplate(organization, templateName string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + organizationTemplates, ok := s.templateIDsByOrg[organization] + if !ok { + return false + } + _, ok = organizationTemplates[templateName] + return ok +} + +func (s *mockCoderServerState) setTemplateVersionTemplateID(templateVersionID, templateID uuid.UUID) { + if templateVersionID == uuid.Nil { + panic("assertion failed: template version ID must not be nil") + } + if templateID == uuid.Nil { + panic("assertion failed: template ID must not be nil") + } + + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now().UTC() + templateIDCopy := templateID + version, ok := s.templateVersionsByID[templateVersionID] + if !ok { + version = codersdk.TemplateVersion{ + ID: templateVersionID, + OrganizationID: s.organization.ID, + CreatedAt: now, + } + } + version.TemplateID = &templateIDCopy + version.UpdatedAt = now + + s.templateVersionsByID[templateVersionID] = version +} + +func (s *mockCoderServerState) hasWorkspace(owner, workspaceName string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + userWorkspaces, ok := s.workspaceIDsByUser[owner] + if !ok { + return false + } + _, ok = userWorkspaces[workspaceName] + return ok +} + +func (s *mockCoderServerState) workspaceLatestBuildTemplateVersionID(owner, workspaceName string) (uuid.UUID, bool) { + s.mu.Lock() + defer s.mu.Unlock() + + userWorkspaces, ok := s.workspaceIDsByUser[owner] + if !ok { + return uuid.Nil, false + } + + workspaceID, ok := userWorkspaces[workspaceName] + if !ok { + return uuid.Nil, false + } + + workspace, ok := s.workspacesByID[workspaceID] + if !ok { + return uuid.Nil, false + } + + return workspace.LatestBuild.TemplateVersionID, true +} + +func (s *mockCoderServerState) buildTransitionsSnapshot() []codersdk.WorkspaceTransition { + s.mu.Lock() + defer s.mu.Unlock() + + transitions := make([]codersdk.WorkspaceTransition, len(s.buildTransitions)) + copy(transitions, s.buildTransitions) + return transitions +} + +func (s *mockCoderServerState) setBuildTransitionFailure(transition codersdk.WorkspaceTransition, statusCode int) { + s.mu.Lock() + defer s.mu.Unlock() + + if transition == "" { + panic("assertion failed: transition must not be empty") } - if template.Name != "starter-template" { - t.Fatalf("expected starter-template, got %q", template.Name) + if statusCode < http.StatusBadRequest || statusCode > http.StatusNetworkAuthenticationRequired { + panic(fmt.Sprintf("assertion failed: invalid HTTP status code %d", statusCode)) } + + s.failBuildTransitions[transition] = statusCode } -func TestTemplateStorageGetNotFound(t *testing.T) { +func newTestClientProvider(t *testing.T, serverURL string) coder.ClientProvider { t.Helper() - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") + parsedURL, err := url.Parse(serverURL) + if err != nil { + t.Fatalf("parse mock server URL %q: %v", serverURL, err) + } + + client := codersdk.New(parsedURL) + client.SetSessionToken("test-session-token") - _, err := templateStorage.Get(ctx, "does-not-exist", nil) - if !apierrors.IsNotFound(err) { - t.Fatalf("expected NotFound error, got %v", err) + return &coder.StaticClientProvider{Client: client, Namespace: "control-plane"} +} + +func namespacedContext(namespace string) context.Context { + return genericapirequest.WithNamespace(context.Background(), namespace) +} + +func containsTransition(transitions []codersdk.WorkspaceTransition, transition codersdk.WorkspaceTransition) bool { + for _, got := range transitions { + if got == transition { + return true + } + } + return false +} + +func statusFromTransition(transition codersdk.WorkspaceTransition) codersdk.WorkspaceStatus { + switch transition { + case codersdk.WorkspaceTransitionStart: + return codersdk.WorkspaceStatusRunning + case codersdk.WorkspaceTransitionStop: + return codersdk.WorkspaceStatusStopped + case codersdk.WorkspaceTransitionDelete: + return codersdk.WorkspaceStatusDeleted + default: + return codersdk.WorkspaceStatusPending + } +} + +func splitPath(path string) []string { + trimmed := strings.Trim(path, "/") + if trimmed == "" { + return nil + } + + return strings.Split(trimmed, "/") +} + +func hasSegments(segments []string, expected ...string) bool { + if len(segments) < len(expected) { + return false + } + + for i, segment := range expected { + if segments[i] != segment { + return false + } } + + return true +} + +func writeJSON(w http.ResponseWriter, statusCode int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + _ = json.NewEncoder(w).Encode(payload) +} + +func writeCoderError(w http.ResponseWriter, statusCode int, message string) { + writeJSON(w, statusCode, codersdk.Response{Message: message}) } diff --git a/internal/aggregated/storage/template.go b/internal/aggregated/storage/template.go index 3630104b..8c94e211 100644 --- a/internal/aggregated/storage/template.go +++ b/internal/aggregated/storage/template.go @@ -3,102 +3,48 @@ package storage import ( "context" "fmt" - "sort" - "sync" - "time" apierrors "k8s.io/apimachinery/pkg/api/errors" metainternalversion "k8s.io/apimachinery/pkg/apis/meta/internalversion" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - genericapirequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/registry/rest" aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder-k8s/internal/aggregated/coder" + "github.com/coder/coder-k8s/internal/aggregated/convert" + "github.com/coder/coder/v2/codersdk" ) var ( _ rest.Storage = (*TemplateStorage)(nil) _ rest.Getter = (*TemplateStorage)(nil) _ rest.Lister = (*TemplateStorage)(nil) + _ rest.Creater = (*TemplateStorage)(nil) //nolint:misspell // Kubernetes rest interface name is Creater. _ rest.Updater = (*TemplateStorage)(nil) _ rest.GracefulDeleter = (*TemplateStorage)(nil) _ rest.Scoper = (*TemplateStorage)(nil) _ rest.SingularNameProvider = (*TemplateStorage)(nil) ) -// TemplateStorage provides hardcoded CoderTemplate objects. +// TemplateStorage provides codersdk-backed CoderTemplate objects. type TemplateStorage struct { - mu sync.RWMutex + provider coder.ClientProvider tableConvertor rest.TableConvertor - templates map[string]*aggregationv1alpha1.CoderTemplate } -// NewTemplateStorage builds hardcoded storage for CoderTemplate resources. -func NewTemplateStorage() *TemplateStorage { - starterDeadline := metav1.NewTime(time.Date(2030, time.January, 4, 18, 0, 0, 0, time.UTC)) - platformDeadline := metav1.NewTime(time.Date(2030, time.January, 5, 18, 0, 0, 0, time.UTC)) - docsDeadline := metav1.NewTime(time.Date(2030, time.January, 6, 18, 0, 0, 0, time.UTC)) +// NewTemplateStorage builds codersdk-backed storage for CoderTemplate resources. +func NewTemplateStorage(provider coder.ClientProvider) *TemplateStorage { + if provider == nil { + panic("assertion failed: template client provider must not be nil") + } return &TemplateStorage{ + provider: provider, tableConvertor: rest.NewDefaultTableConvertor(aggregationv1alpha1.Resource("codertemplates")), - templates: map[string]*aggregationv1alpha1.CoderTemplate{ - templateKey("default", "starter-template"): { - TypeMeta: metav1.TypeMeta{ - Kind: "CoderTemplate", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "starter-template", - Namespace: "default", - ResourceVersion: "1", - Generation: 1, - }, - Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, - Status: aggregationv1alpha1.CoderTemplateStatus{ - AutoShutdown: &starterDeadline, - }, - }, - templateKey("default", "platform-template"): { - TypeMeta: metav1.TypeMeta{ - Kind: "CoderTemplate", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "platform-template", - Namespace: "default", - ResourceVersion: "1", - Generation: 1, - }, - Spec: aggregationv1alpha1.CoderTemplateSpec{Running: false}, - Status: aggregationv1alpha1.CoderTemplateStatus{ - AutoShutdown: &platformDeadline, - }, - }, - templateKey("sandbox", "docs-template"): { - TypeMeta: metav1.TypeMeta{ - Kind: "CoderTemplate", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "docs-template", - Namespace: "sandbox", - ResourceVersion: "1", - Generation: 1, - }, - Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, - Status: aggregationv1alpha1.CoderTemplateStatus{ - AutoShutdown: &docsDeadline, - }, - }, - }, } } -func templateKey(namespace, name string) string { - return namespace + "/" + name -} - // New returns an empty CoderTemplate object. func (s *TemplateStorage) New() runtime.Object { return &aggregationv1alpha1.CoderTemplate{} @@ -122,7 +68,7 @@ func (s *TemplateStorage) NewList() runtime.Object { return &aggregationv1alpha1.CoderTemplateList{} } -// Get returns a hardcoded CoderTemplate by name. +// Get fetches a CoderTemplate by organization and template name. func (s *TemplateStorage) Get(ctx context.Context, name string, _ *metav1.GetOptions) (runtime.Object, error) { if s == nil { return nil, fmt.Errorf("assertion failed: template storage must not be nil") @@ -134,30 +80,35 @@ func (s *TemplateStorage) Get(ctx context.Context, name string, _ *metav1.GetOpt return nil, fmt.Errorf("assertion failed: template name must not be empty") } - namespace := genericapirequest.NamespaceValue(ctx) + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, badNamespaceErr + } - s.mu.RLock() - defer s.mu.RUnlock() + orgName, templateName, err := coder.ParseTemplateName(name) + if err != nil { + return nil, apierrors.NewBadRequest(fmt.Sprintf("invalid template name %q: %v", name, err)) + } - if namespace != "" { - template, ok := s.templates[templateKey(namespace, name)] - if !ok { - return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) - } - return template.DeepCopy(), nil + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, wrapClientError(err) } - template, found, ambiguous := s.findTemplateByNameLocked(name) - if ambiguous { - return nil, apierrors.NewBadRequest(fmt.Sprintf("template name %q is ambiguous across namespaces; specify namespace", name)) + org, err := sdk.OrganizationByName(ctx, orgName) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), name) } - if !found { - return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) + + template, err := sdk.TemplateByName(ctx, org.ID, templateName) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), name) } - return template.DeepCopy(), nil + + return convert.TemplateToK8s(namespace, template), nil } -// List returns hardcoded CoderTemplate objects. +// List fetches CoderTemplate objects from codersdk. func (s *TemplateStorage) List(ctx context.Context, _ *metainternalversion.ListOptions) (runtime.Object, error) { if s == nil { return nil, fmt.Errorf("assertion failed: template storage must not be nil") @@ -166,36 +117,42 @@ func (s *TemplateStorage) List(ctx context.Context, _ *metainternalversion.ListO return nil, fmt.Errorf("assertion failed: context must not be nil") } - namespace := genericapirequest.NamespaceValue(ctx) + namespace, badNamespaceErr := namespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, badNamespaceErr + } + + responseNamespace, responseNamespaceErr := namespaceForListConversion(namespace, s.provider) + if responseNamespaceErr != nil { + return nil, responseNamespaceErr + } + + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, wrapClientError(err) + } + + templates, err := sdk.Templates(ctx, codersdk.TemplateFilter{}) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), "") + } + list := &aggregationv1alpha1.CoderTemplateList{ TypeMeta: metav1.TypeMeta{ Kind: "CoderTemplateList", APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, - Items: make([]aggregationv1alpha1.CoderTemplate, 0), - } - - s.mu.RLock() - defer s.mu.RUnlock() - - keys := make([]string, 0, len(s.templates)) - for key := range s.templates { - keys = append(keys, key) + Items: make([]aggregationv1alpha1.CoderTemplate, 0, len(templates)), } - sort.Strings(keys) - for _, key := range keys { - template := s.templates[key] - if namespace != "" && template.Namespace != namespace { - continue - } - list.Items = append(list.Items, *template.DeepCopy()) + for _, template := range templates { + list.Items = append(list.Items, *convert.TemplateToK8s(responseNamespace, template)) } return list, nil } -// Create inserts a CoderTemplate into the in-memory store. +// Create creates a CoderTemplate through codersdk. func (s *TemplateStorage) Create( ctx context.Context, obj runtime.Object, @@ -212,56 +169,72 @@ func (s *TemplateStorage) Create( return nil, fmt.Errorf("assertion failed: object must not be nil") } - template, ok := obj.(*aggregationv1alpha1.CoderTemplate) + templateObj, ok := obj.(*aggregationv1alpha1.CoderTemplate) if !ok { return nil, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderTemplate, got %T", obj)) } + if createValidation != nil { + if err := createValidation(ctx, obj); err != nil { + return nil, err + } + } + if templateObj.Name == "" { + return nil, apierrors.NewBadRequest("metadata.name must not be empty") + } - candidate := template.DeepCopy() - if candidate.Name == "" { - return nil, apierrors.NewBadRequest("metadata.name is required") + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, badNamespaceErr + } + if templateObj.Namespace != "" && templateObj.Namespace != namespace { + return nil, apierrors.NewBadRequest( + fmt.Sprintf("metadata.namespace %q must match request namespace %q", templateObj.Namespace, namespace), + ) } - namespace, err := resolveWriteNamespace(ctx, candidate.Namespace) + orgName, templateName, err := coder.ParseTemplateName(templateObj.Name) if err != nil { - return nil, err + return nil, apierrors.NewBadRequest(fmt.Sprintf("invalid template name %q: %v", templateObj.Name, err)) + } + if templateObj.Spec.Organization != orgName { + return nil, apierrors.NewBadRequest( + fmt.Sprintf( + "spec.organization %q must match organization %q parsed from metadata.name", + templateObj.Spec.Organization, + orgName, + ), + ) } - candidate.Namespace = namespace - ensureTemplateTypeMeta(candidate) - if candidate.Generation == 0 { - candidate.Generation = 1 - } - if candidate.CreationTimestamp.IsZero() { - candidate.CreationTimestamp = metav1.Now() + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, wrapClientError(err) } - candidate.ResourceVersion = "1" - if createValidation != nil { - if err := createValidation(ctx, candidate); err != nil { - return nil, err - } + org, err := sdk.OrganizationByName(ctx, orgName) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), templateObj.Name) } - key := templateKey(candidate.Namespace, candidate.Name) - - s.mu.Lock() - defer s.mu.Unlock() + request, err := convert.TemplateCreateRequestFromK8s(templateObj, templateName) + if err != nil { + return nil, apierrors.NewBadRequest(err.Error()) + } - if _, exists := s.templates[key]; exists { - return nil, apierrors.NewAlreadyExists(aggregationv1alpha1.Resource("codertemplates"), candidate.Name) + createdTemplate, err := sdk.CreateTemplate(ctx, org.ID, request) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), templateObj.Name) } - s.templates[key] = candidate.DeepCopy() - return candidate.DeepCopy(), nil + return convert.TemplateToK8s(namespace, createdTemplate), nil } -// Update modifies an existing CoderTemplate in the in-memory store. +// Update applies a legacy-compatible template update. func (s *TemplateStorage) Update( ctx context.Context, name string, objInfo rest.UpdatedObjectInfo, - createValidation rest.ValidateObjectFunc, + _ rest.ValidateObjectFunc, updateValidation rest.ValidateObjectUpdateFunc, forceAllowCreate bool, _ *metav1.UpdateOptions, @@ -278,124 +251,91 @@ func (s *TemplateStorage) Update( if objInfo == nil { return nil, false, fmt.Errorf("assertion failed: updated object info must not be nil") } - - namespace := genericapirequest.NamespaceValue(ctx) - if namespace == "" { - return nil, false, apierrors.NewBadRequest("namespace is required") + if forceAllowCreate { + return nil, false, apierrors.NewMethodNotSupported( + aggregationv1alpha1.Resource("codertemplates"), + "create on update", + ) } - key := templateKey(namespace, name) - - s.mu.Lock() - defer s.mu.Unlock() - - existing, exists := s.templates[key] - if !exists { - if !forceAllowCreate { - return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) - } - - createdObj, err := objInfo.UpdatedObject(ctx, &aggregationv1alpha1.CoderTemplate{}) - if err != nil { - return nil, false, err - } - createdTemplate, ok := createdObj.(*aggregationv1alpha1.CoderTemplate) - if !ok { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderTemplate, got %T", createdObj)) - } - - candidate := createdTemplate.DeepCopy() - if candidate.Name == "" { - candidate.Name = name - } - if candidate.Name != name { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) - } - if candidate.Namespace == "" { - candidate.Namespace = namespace - } - if candidate.Namespace != namespace { - return nil, false, apierrors.NewBadRequest( - fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), - ) - } - - ensureTemplateTypeMeta(candidate) - if candidate.Generation == 0 { - candidate.Generation = 1 - } - if candidate.CreationTimestamp.IsZero() { - candidate.CreationTimestamp = metav1.Now() - } - candidate.ResourceVersion = "1" - - if createValidation != nil { - if err := createValidation(ctx, candidate); err != nil { - return nil, false, err - } - } + currentObj, err := s.Get(ctx, name, nil) + if err != nil { + return nil, false, err + } - s.templates[key] = candidate.DeepCopy() - return candidate.DeepCopy(), true, nil + currentObjForUpdate := currentObj.DeepCopyObject() + if currentObjForUpdate == nil { + return nil, false, fmt.Errorf("assertion failed: current template object deep copy must not be nil") } - updatedObj, err := objInfo.UpdatedObject(ctx, existing.DeepCopy()) + updatedObj, err := objInfo.UpdatedObject(ctx, currentObjForUpdate) if err != nil { return nil, false, err } + if updatedObj == nil { + return nil, false, fmt.Errorf("assertion failed: updated template object must not be nil") + } updatedTemplate, ok := updatedObj.(*aggregationv1alpha1.CoderTemplate) if !ok { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderTemplate, got %T", updatedObj)) + return nil, false, fmt.Errorf("assertion failed: expected *CoderTemplate, got %T", updatedObj) } - - candidate := updatedTemplate.DeepCopy() - if candidate.Name == "" { - candidate.Name = name + currentTemplate, ok := currentObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + return nil, false, fmt.Errorf("assertion failed: expected *CoderTemplate, got %T", currentObj) } - if candidate.Name != name { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) + + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, false, badNamespaceErr } - if candidate.Namespace == "" { - candidate.Namespace = namespace + if updatedTemplate.Name != "" && updatedTemplate.Name != name { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("updated object metadata.name %q must match request name %q", updatedTemplate.Name, name), + ) } - if candidate.Namespace != namespace { + if updatedTemplate.Namespace != "" && updatedTemplate.Namespace != namespace { return nil, false, apierrors.NewBadRequest( - fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), + fmt.Sprintf("metadata.namespace %q does not match request namespace %q", updatedTemplate.Namespace, namespace), ) } - - if candidate.ResourceVersion == "" { + if updatedTemplate.ResourceVersion == "" { return nil, false, apierrors.NewBadRequest("metadata.resourceVersion is required for update") } - if candidate.ResourceVersion != existing.ResourceVersion { + if updatedTemplate.ResourceVersion != currentTemplate.ResourceVersion { return nil, false, apierrors.NewConflict( aggregationv1alpha1.Resource("codertemplates"), name, - fmt.Errorf("resourceVersion %q does not match current value %q", candidate.ResourceVersion, existing.ResourceVersion), + fmt.Errorf( + "resource version mismatch: got %q, current is %q", + updatedTemplate.ResourceVersion, + currentTemplate.ResourceVersion, + ), ) } - - candidate.Status = existing.Status - candidate.CreationTimestamp = existing.CreationTimestamp - candidate.Generation = existing.Generation + 1 - candidateFinalResourceVersion, err := incrementResourceVersion(existing.ResourceVersion) - if err != nil { - return nil, false, err - } - candidate.ResourceVersion = candidateFinalResourceVersion - ensureTemplateTypeMeta(candidate) - if updateValidation != nil { - if err := updateValidation(ctx, candidate, existing); err != nil { + if err := updateValidation(ctx, updatedTemplate, currentTemplate); err != nil { return nil, false, err } } - s.templates[key] = candidate.DeepCopy() - return candidate.DeepCopy(), false, nil + // Template updates via codersdk are currently limited. The legacy spec.running + // field remains for compatibility with in-repo callers and is a no-op in the + // Coder backend. Reject updates to all other spec fields to avoid drift between + // accepted update payloads and persisted backend state. + if updatedTemplate.Spec.Organization != currentTemplate.Spec.Organization || + (updatedTemplate.Spec.VersionID != "" && updatedTemplate.Spec.VersionID != currentTemplate.Spec.VersionID) || + (updatedTemplate.Spec.DisplayName != "" && updatedTemplate.Spec.DisplayName != currentTemplate.Spec.DisplayName) || + (updatedTemplate.Spec.Description != "" && updatedTemplate.Spec.Description != currentTemplate.Spec.Description) || + (updatedTemplate.Spec.Icon != "" && updatedTemplate.Spec.Icon != currentTemplate.Spec.Icon) { + return nil, false, apierrors.NewBadRequest( + "template update only supports changing spec.running; other spec fields are immutable", + ) + } + + return currentTemplate, false, nil } -// Delete removes a CoderTemplate from the in-memory store. +// Delete deletes a CoderTemplate through codersdk. func (s *TemplateStorage) Delete( ctx context.Context, name string, @@ -412,49 +352,42 @@ func (s *TemplateStorage) Delete( return nil, false, fmt.Errorf("assertion failed: template name must not be empty") } - namespace := genericapirequest.NamespaceValue(ctx) - - s.mu.Lock() - defer s.mu.Unlock() - - var ( - key string - template *aggregationv1alpha1.CoderTemplate - ) - if namespace != "" { - key = templateKey(namespace, name) - template = s.templates[key] - } else { - matchedKeys := make([]string, 0) - for candidateKey, candidateTemplate := range s.templates { - if candidateTemplate.Name == name { - matchedKeys = append(matchedKeys, candidateKey) - } - } - if len(matchedKeys) > 1 { - return nil, false, apierrors.NewBadRequest( - fmt.Sprintf("template name %q is ambiguous across namespaces; specify namespace", name), - ) - } - if len(matchedKeys) == 1 { - key = matchedKeys[0] - template = s.templates[key] - } + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, false, badNamespaceErr } - if template == nil { - return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) + orgName, templateName, err := coder.ParseTemplateName(name) + if err != nil { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("invalid template name %q: %v", name, err)) + } + + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, false, wrapClientError(err) + } + + org, err := sdk.OrganizationByName(ctx, orgName) + if err != nil { + return nil, false, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), name) + } + + template, err := sdk.TemplateByName(ctx, org.ID, templateName) + if err != nil { + return nil, false, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), name) } if deleteValidation != nil { - if err := deleteValidation(ctx, template.DeepCopy()); err != nil { - return nil, false, err + if validationErr := deleteValidation(ctx, convert.TemplateToK8s(namespace, template)); validationErr != nil { + return nil, false, validationErr } } - deleted := template.DeepCopy() - delete(s.templates, key) - return deleted, true, nil + if err := sdk.DeleteTemplate(ctx, template.ID); err != nil { + return nil, false, coder.MapCoderError(err, aggregationv1alpha1.Resource("codertemplates"), name) + } + + return &metav1.Status{Status: metav1.StatusSuccess}, true, nil } // ConvertToTable converts a template object or list into kubectl table output. @@ -469,32 +402,18 @@ func (s *TemplateStorage) ConvertToTable(ctx context.Context, object, tableOptio return s.tableConvertor.ConvertToTable(ctx, object, tableOptions) } -func ensureTemplateTypeMeta(template *aggregationv1alpha1.CoderTemplate) { - if template == nil { - panic("assertion failed: template must not be nil") - } - template.TypeMeta = metav1.TypeMeta{ - Kind: "CoderTemplate", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), +func (s *TemplateStorage) clientForNamespace(ctx context.Context, namespace string) (*codersdk.Client, error) { + if s.provider == nil { + return nil, fmt.Errorf("assertion failed: template client provider must not be nil") } -} -func (s *TemplateStorage) findTemplateByNameLocked(name string) (*aggregationv1alpha1.CoderTemplate, bool, bool) { - matchedKeys := make([]string, 0) - for key, template := range s.templates { - if template.Name == name { - matchedKeys = append(matchedKeys, key) - } - } - if len(matchedKeys) == 0 { - return nil, false, false - } - if len(matchedKeys) > 1 { - return nil, false, true + sdk, err := s.provider.ClientForNamespace(ctx, namespace) + if err != nil { + return nil, fmt.Errorf("resolve codersdk client for namespace %q: %w", namespace, err) } - template := s.templates[matchedKeys[0]] - if template == nil { - return nil, false, false + if sdk == nil { + return nil, fmt.Errorf("assertion failed: template client provider returned nil codersdk client") } - return template, true, false + + return sdk, nil } diff --git a/internal/aggregated/storage/template_test.go b/internal/aggregated/storage/template_test.go deleted file mode 100644 index 73d2bec7..00000000 --- a/internal/aggregated/storage/template_test.go +++ /dev/null @@ -1,231 +0,0 @@ -package storage - -import ( - "context" - "testing" - "time" - - apierrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - genericapirequest "k8s.io/apiserver/pkg/endpoints/request" - "k8s.io/apiserver/pkg/registry/rest" - - aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" -) - -func TestTemplateStorageCRUDLifecycle(t *testing.T) { - t.Helper() - - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - createdObj, err := templateStorage.Create(ctx, &aggregationv1alpha1.CoderTemplate{ - ObjectMeta: metav1.ObjectMeta{Name: "unit-template"}, - Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, - }, rest.ValidateAllObjectFunc, nil) - if err != nil { - t.Fatalf("create template: %v", err) - } - - created, ok := createdObj.(*aggregationv1alpha1.CoderTemplate) - if !ok { - t.Fatalf("expected *CoderTemplate from create, got %T", createdObj) - } - if created.Namespace != "default" { - t.Fatalf("expected namespace default, got %q", created.Namespace) - } - if created.ResourceVersion != "1" { - t.Fatalf("expected resourceVersion 1, got %q", created.ResourceVersion) - } - if created.Generation != 1 { - t.Fatalf("expected generation 1, got %d", created.Generation) - } - - toUpdate := created.DeepCopy() - toUpdate.Spec.Running = false - toUpdate.ResourceVersion = created.ResourceVersion - updatedObj, createdOnUpdate, err := templateStorage.Update( - ctx, - toUpdate.Name, - rest.DefaultUpdatedObjectInfo(toUpdate), - nil, - nil, - false, - nil, - ) - if err != nil { - t.Fatalf("update template: %v", err) - } - if createdOnUpdate { - t.Fatal("expected update of existing template, got createdOnUpdate=true") - } - - updated, ok := updatedObj.(*aggregationv1alpha1.CoderTemplate) - if !ok { - t.Fatalf("expected *CoderTemplate from update, got %T", updatedObj) - } - if updated.Spec.Running { - t.Fatalf("expected running=false after update, got %+v", updated.Spec) - } - if updated.ResourceVersion == created.ResourceVersion { - t.Fatalf("expected resourceVersion to change, got %q", updated.ResourceVersion) - } - if updated.Generation != created.Generation+1 { - t.Fatalf("expected generation increment to %d, got %d", created.Generation+1, updated.Generation) - } - - deletedObj, deletedNow, err := templateStorage.Delete(ctx, created.Name, nil, nil) - if err != nil { - t.Fatalf("delete template: %v", err) - } - if !deletedNow { - t.Fatal("expected immediate delete") - } - if _, ok := deletedObj.(*aggregationv1alpha1.CoderTemplate); !ok { - t.Fatalf("expected *CoderTemplate from delete, got %T", deletedObj) - } - - _, err = templateStorage.Get(ctx, created.Name, nil) - if !apierrors.IsNotFound(err) { - t.Fatalf("expected NotFound after delete, got %v", err) - } -} - -func TestTemplateStorageCreateAlreadyExists(t *testing.T) { - t.Helper() - - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - _, err := templateStorage.Create(ctx, &aggregationv1alpha1.CoderTemplate{ - ObjectMeta: metav1.ObjectMeta{Name: "starter-template"}, - Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, - }, nil, nil) - if !apierrors.IsAlreadyExists(err) { - t.Fatalf("expected AlreadyExists error, got %v", err) - } -} - -func TestTemplateStorageUpdateRejectsNamespaceChange(t *testing.T) { - t.Helper() - - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - currentObj, err := templateStorage.Get(ctx, "starter-template", nil) - if err != nil { - t.Fatalf("get template: %v", err) - } - current := currentObj.(*aggregationv1alpha1.CoderTemplate) - - modified := current.DeepCopy() - modified.Namespace = "sandbox" - modified.ResourceVersion = current.ResourceVersion - - _, _, err = templateStorage.Update( - ctx, - modified.Name, - rest.DefaultUpdatedObjectInfo(modified), - nil, - nil, - false, - nil, - ) - if !apierrors.IsBadRequest(err) { - t.Fatalf("expected BadRequest for namespace mismatch, got %v", err) - } -} - -func TestTemplateStorageUpdateIgnoresStatusWrites(t *testing.T) { - t.Helper() - - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - currentObj, err := templateStorage.Get(ctx, "starter-template", nil) - if err != nil { - t.Fatalf("get template: %v", err) - } - current := currentObj.(*aggregationv1alpha1.CoderTemplate) - if current.Status.AutoShutdown == nil { - t.Fatal("expected seeded template status autoShutdown") - } - - modified := current.DeepCopy() - modified.Spec.Running = !current.Spec.Running - modified.ResourceVersion = current.ResourceVersion - overrideDeadline := metav1.NewTime(time.Date(2040, time.January, 1, 0, 0, 0, 0, time.UTC)) - modified.Status.AutoShutdown = &overrideDeadline - - updatedObj, _, err := templateStorage.Update( - ctx, - modified.Name, - rest.DefaultUpdatedObjectInfo(modified), - nil, - nil, - false, - nil, - ) - if err != nil { - t.Fatalf("update template: %v", err) - } - - updated := updatedObj.(*aggregationv1alpha1.CoderTemplate) - if updated.Status.AutoShutdown == nil { - t.Fatal("expected status autoShutdown to remain present") - } - if !updated.Status.AutoShutdown.Equal(current.Status.AutoShutdown) { - t.Fatalf("expected status to remain unchanged, got %s want %s", updated.Status.AutoShutdown, current.Status.AutoShutdown) - } -} - -func TestTemplateStorageDeleteAmbiguousWithoutNamespace(t *testing.T) { - t.Helper() - - templateStorage := NewTemplateStorage() - - _, err := templateStorage.Create( - genericapirequest.WithNamespace(context.Background(), "sandbox"), - &aggregationv1alpha1.CoderTemplate{ObjectMeta: metav1.ObjectMeta{Name: "starter-template"}}, - nil, - nil, - ) - if err != nil { - t.Fatalf("seed same-name template in sandbox namespace: %v", err) - } - - _, _, err = templateStorage.Delete(context.Background(), "starter-template", nil, nil) - if !apierrors.IsBadRequest(err) { - t.Fatalf("expected BadRequest for ambiguous delete, got %v", err) - } -} - -func TestTemplateStorageUpdateRequiresResourceVersion(t *testing.T) { - t.Helper() - - templateStorage := NewTemplateStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - currentObj, err := templateStorage.Get(ctx, "starter-template", nil) - if err != nil { - t.Fatalf("get template: %v", err) - } - current := currentObj.(*aggregationv1alpha1.CoderTemplate) - - modified := current.DeepCopy() - modified.Spec.Running = !current.Spec.Running - modified.ResourceVersion = "" - - _, _, err = templateStorage.Update( - ctx, - modified.Name, - rest.DefaultUpdatedObjectInfo(modified), - nil, - nil, - false, - nil, - ) - if !apierrors.IsBadRequest(err) { - t.Fatalf("expected BadRequest when resourceVersion is missing, got %v", err) - } -} diff --git a/internal/aggregated/storage/workspace.go b/internal/aggregated/storage/workspace.go index 522d1d1e..a3471559 100644 --- a/internal/aggregated/storage/workspace.go +++ b/internal/aggregated/storage/workspace.go @@ -1,13 +1,10 @@ -// Package storage provides hardcoded in-memory storage implementations for aggregated API resources. package storage import ( "context" "fmt" - "sort" - "sync" - "time" + "github.com/google/uuid" apierrors "k8s.io/apimachinery/pkg/api/errors" metainternalversion "k8s.io/apimachinery/pkg/apis/meta/internalversion" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -16,90 +13,40 @@ import ( "k8s.io/apiserver/pkg/registry/rest" aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder-k8s/internal/aggregated/coder" + "github.com/coder/coder-k8s/internal/aggregated/convert" + "github.com/coder/coder/v2/codersdk" ) var ( _ rest.Storage = (*WorkspaceStorage)(nil) _ rest.Getter = (*WorkspaceStorage)(nil) _ rest.Lister = (*WorkspaceStorage)(nil) + _ rest.Creater = (*WorkspaceStorage)(nil) //nolint:misspell // Kubernetes rest interface name is Creater. _ rest.Updater = (*WorkspaceStorage)(nil) _ rest.GracefulDeleter = (*WorkspaceStorage)(nil) _ rest.Scoper = (*WorkspaceStorage)(nil) _ rest.SingularNameProvider = (*WorkspaceStorage)(nil) ) -// WorkspaceStorage provides hardcoded CoderWorkspace objects. +// WorkspaceStorage provides codersdk-backed CoderWorkspace objects. type WorkspaceStorage struct { - mu sync.RWMutex + provider coder.ClientProvider tableConvertor rest.TableConvertor - workspaces map[string]*aggregationv1alpha1.CoderWorkspace } -// NewWorkspaceStorage builds hardcoded storage for CoderWorkspace resources. -func NewWorkspaceStorage() *WorkspaceStorage { - workspaceDeadline := metav1.NewTime(time.Date(2030, time.January, 1, 18, 0, 0, 0, time.UTC)) - stagingDeadline := metav1.NewTime(time.Date(2030, time.January, 2, 18, 0, 0, 0, time.UTC)) - sandboxDeadline := metav1.NewTime(time.Date(2030, time.January, 3, 18, 0, 0, 0, time.UTC)) +// NewWorkspaceStorage builds codersdk-backed storage for CoderWorkspace resources. +func NewWorkspaceStorage(provider coder.ClientProvider) *WorkspaceStorage { + if provider == nil { + panic("assertion failed: workspace client provider must not be nil") + } return &WorkspaceStorage{ + provider: provider, tableConvertor: rest.NewDefaultTableConvertor(aggregationv1alpha1.Resource("coderworkspaces")), - workspaces: map[string]*aggregationv1alpha1.CoderWorkspace{ - workspaceKey("default", "dev-workspace"): { - TypeMeta: metav1.TypeMeta{ - Kind: "CoderWorkspace", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "dev-workspace", - Namespace: "default", - ResourceVersion: "1", - Generation: 1, - }, - Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, - Status: aggregationv1alpha1.CoderWorkspaceStatus{ - AutoShutdown: &workspaceDeadline, - }, - }, - workspaceKey("default", "staging-workspace"): { - TypeMeta: metav1.TypeMeta{ - Kind: "CoderWorkspace", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "staging-workspace", - Namespace: "default", - ResourceVersion: "1", - Generation: 1, - }, - Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: false}, - Status: aggregationv1alpha1.CoderWorkspaceStatus{ - AutoShutdown: &stagingDeadline, - }, - }, - workspaceKey("sandbox", "sandbox-workspace"): { - TypeMeta: metav1.TypeMeta{ - Kind: "CoderWorkspace", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "sandbox-workspace", - Namespace: "sandbox", - ResourceVersion: "1", - Generation: 1, - }, - Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, - Status: aggregationv1alpha1.CoderWorkspaceStatus{ - AutoShutdown: &sandboxDeadline, - }, - }, - }, } } -func workspaceKey(namespace, name string) string { - return namespace + "/" + name -} - // New returns an empty CoderWorkspace object. func (s *WorkspaceStorage) New() runtime.Object { return &aggregationv1alpha1.CoderWorkspace{} @@ -123,7 +70,7 @@ func (s *WorkspaceStorage) NewList() runtime.Object { return &aggregationv1alpha1.CoderWorkspaceList{} } -// Get returns a hardcoded CoderWorkspace by name. +// Get fetches a CoderWorkspace by organization, owner, and workspace name. func (s *WorkspaceStorage) Get(ctx context.Context, name string, _ *metav1.GetOptions) (runtime.Object, error) { if s == nil { return nil, fmt.Errorf("assertion failed: workspace storage must not be nil") @@ -135,30 +82,33 @@ func (s *WorkspaceStorage) Get(ctx context.Context, name string, _ *metav1.GetOp return nil, fmt.Errorf("assertion failed: workspace name must not be empty") } - namespace := genericapirequest.NamespaceValue(ctx) + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, badNamespaceErr + } - s.mu.RLock() - defer s.mu.RUnlock() + orgName, userName, workspaceName, err := coder.ParseWorkspaceName(name) + if err != nil { + return nil, apierrors.NewBadRequest(fmt.Sprintf("invalid workspace name %q: %v", name, err)) + } - if namespace != "" { - workspace, ok := s.workspaces[workspaceKey(namespace, name)] - if !ok { - return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) - } - return workspace.DeepCopy(), nil + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, wrapClientError(err) } - workspace, found, ambiguous := s.findWorkspaceByNameLocked(name) - if ambiguous { - return nil, apierrors.NewBadRequest(fmt.Sprintf("workspace name %q is ambiguous across namespaces; specify namespace", name)) + workspace, err := sdk.WorkspaceByOwnerAndName(ctx, userName, workspaceName, codersdk.WorkspaceOptions{}) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), name) } - if !found { + if workspace.OrganizationName != orgName { return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) } - return workspace.DeepCopy(), nil + + return convert.WorkspaceToK8s(namespace, workspace), nil } -// List returns hardcoded CoderWorkspace objects. +// List fetches CoderWorkspace objects from codersdk. func (s *WorkspaceStorage) List(ctx context.Context, _ *metainternalversion.ListOptions) (runtime.Object, error) { if s == nil { return nil, fmt.Errorf("assertion failed: workspace storage must not be nil") @@ -167,36 +117,42 @@ func (s *WorkspaceStorage) List(ctx context.Context, _ *metainternalversion.List return nil, fmt.Errorf("assertion failed: context must not be nil") } - namespace := genericapirequest.NamespaceValue(ctx) + namespace, badNamespaceErr := namespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, badNamespaceErr + } + + responseNamespace, responseNamespaceErr := namespaceForListConversion(namespace, s.provider) + if responseNamespaceErr != nil { + return nil, responseNamespaceErr + } + + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, wrapClientError(err) + } + + workspacesResponse, err := sdk.Workspaces(ctx, codersdk.WorkspaceFilter{}) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), "") + } + list := &aggregationv1alpha1.CoderWorkspaceList{ TypeMeta: metav1.TypeMeta{ Kind: "CoderWorkspaceList", APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, - Items: make([]aggregationv1alpha1.CoderWorkspace, 0), + Items: make([]aggregationv1alpha1.CoderWorkspace, 0, len(workspacesResponse.Workspaces)), } - s.mu.RLock() - defer s.mu.RUnlock() - - keys := make([]string, 0, len(s.workspaces)) - for key := range s.workspaces { - keys = append(keys, key) - } - sort.Strings(keys) - - for _, key := range keys { - workspace := s.workspaces[key] - if namespace != "" && workspace.Namespace != namespace { - continue - } - list.Items = append(list.Items, *workspace.DeepCopy()) + for _, workspace := range workspacesResponse.Workspaces { + list.Items = append(list.Items, *convert.WorkspaceToK8s(responseNamespace, workspace)) } return list, nil } -// Create inserts a CoderWorkspace into the in-memory store. +// Create creates a CoderWorkspace through codersdk. func (s *WorkspaceStorage) Create( ctx context.Context, obj runtime.Object, @@ -213,56 +169,131 @@ func (s *WorkspaceStorage) Create( return nil, fmt.Errorf("assertion failed: object must not be nil") } - workspace, ok := obj.(*aggregationv1alpha1.CoderWorkspace) + workspaceObj, ok := obj.(*aggregationv1alpha1.CoderWorkspace) if !ok { return nil, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderWorkspace, got %T", obj)) } + if createValidation != nil { + if err := createValidation(ctx, obj); err != nil { + return nil, err + } + } + if workspaceObj.Name == "" { + return nil, apierrors.NewBadRequest("metadata.name must not be empty") + } - candidate := workspace.DeepCopy() - if candidate.Name == "" { - return nil, apierrors.NewBadRequest("metadata.name is required") + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, badNamespaceErr + } + if workspaceObj.Namespace != "" && workspaceObj.Namespace != namespace { + return nil, apierrors.NewBadRequest( + fmt.Sprintf("metadata.namespace %q must match request namespace %q", workspaceObj.Namespace, namespace), + ) } - namespace, err := resolveWriteNamespace(ctx, candidate.Namespace) + orgName, userName, workspaceName, err := coder.ParseWorkspaceName(workspaceObj.Name) if err != nil { - return nil, err + return nil, apierrors.NewBadRequest(fmt.Sprintf("invalid workspace name %q: %v", workspaceObj.Name, err)) + } + if workspaceObj.Spec.Organization != orgName { + return nil, apierrors.NewBadRequest( + fmt.Sprintf( + "spec.organization %q must match organization %q parsed from metadata.name", + workspaceObj.Spec.Organization, + orgName, + ), + ) + } + if workspaceObj.Spec.TemplateName == "" { + return nil, apierrors.NewBadRequest("spec.templateName must not be empty") } - candidate.Namespace = namespace - ensureWorkspaceTypeMeta(candidate) - if candidate.Generation == 0 { - candidate.Generation = 1 + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, wrapClientError(err) } - if candidate.CreationTimestamp.IsZero() { - candidate.CreationTimestamp = metav1.Now() + + org, err := sdk.OrganizationByName(ctx, orgName) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), workspaceObj.Name) } - candidate.ResourceVersion = "1" - if createValidation != nil { - if err := createValidation(ctx, candidate); err != nil { - return nil, err + template, err := sdk.TemplateByName(ctx, org.ID, workspaceObj.Spec.TemplateName) + if err != nil { + return nil, coder.MapCoderError( + err, + aggregationv1alpha1.Resource("codertemplates"), + coder.BuildTemplateName(orgName, workspaceObj.Spec.TemplateName), + ) + } + + if workspaceObj.Spec.TemplateVersionID != "" { + parsedTemplateVersionID, parseErr := uuid.Parse(workspaceObj.Spec.TemplateVersionID) + if parseErr != nil { + return nil, apierrors.NewBadRequest( + fmt.Sprintf( + "invalid workspace spec: invalid templateVersionID %q: %v", + workspaceObj.Spec.TemplateVersionID, + parseErr, + ), + ) + } + + templateVersion, templateVersionErr := sdk.TemplateVersion(ctx, parsedTemplateVersionID) + if templateVersionErr != nil { + return nil, coder.MapCoderError( + templateVersionErr, + aggregationv1alpha1.Resource("coderworkspaces"), + workspaceObj.Name, + ) + } + + if templateVersion.TemplateID == nil || *templateVersion.TemplateID != template.ID { + return nil, apierrors.NewBadRequest( + fmt.Sprintf( + "spec.templateVersionID %q does not belong to template %q", + workspaceObj.Spec.TemplateVersionID, + workspaceObj.Spec.TemplateName, + ), + ) } } - key := workspaceKey(candidate.Namespace, candidate.Name) + request, err := convert.WorkspaceCreateRequestFromK8s(workspaceObj, workspaceName, template.ID) + if err != nil { + return nil, apierrors.NewBadRequest(fmt.Sprintf("invalid workspace spec: %v", err)) + } - s.mu.Lock() - defer s.mu.Unlock() + createdWorkspace, err := sdk.CreateUserWorkspace(ctx, userName, request) + if err != nil { + return nil, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), workspaceObj.Name) + } - if _, exists := s.workspaces[key]; exists { - return nil, apierrors.NewAlreadyExists(aggregationv1alpha1.Resource("coderworkspaces"), candidate.Name) + if !workspaceObj.Spec.Running { + stopBuild, stopErr := sdk.CreateWorkspaceBuild(ctx, createdWorkspace.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransitionStop, + }) + if stopErr == nil { + createdWorkspace.LatestBuild = stopBuild + if !stopBuild.UpdatedAt.IsZero() { + createdWorkspace.UpdatedAt = stopBuild.UpdatedAt + } + } + // The workspace creation already succeeded. Returning a stop transition error here + // would cause client retries to fail with AlreadyExists, while the desired stop + // transition can be retried safely via a subsequent Update. } - s.workspaces[key] = candidate.DeepCopy() - return candidate.DeepCopy(), nil + return convert.WorkspaceToK8s(namespace, createdWorkspace), nil } -// Update modifies an existing CoderWorkspace in the in-memory store. +// Update updates workspace run state through codersdk build transitions. func (s *WorkspaceStorage) Update( ctx context.Context, name string, objInfo rest.UpdatedObjectInfo, - createValidation rest.ValidateObjectFunc, + _ rest.ValidateObjectFunc, updateValidation rest.ValidateObjectUpdateFunc, forceAllowCreate bool, _ *metav1.UpdateOptions, @@ -279,124 +310,125 @@ func (s *WorkspaceStorage) Update( if objInfo == nil { return nil, false, fmt.Errorf("assertion failed: updated object info must not be nil") } - - namespace := genericapirequest.NamespaceValue(ctx) - if namespace == "" { - return nil, false, apierrors.NewBadRequest("namespace is required") + if forceAllowCreate { + return nil, false, apierrors.NewMethodNotSupported( + aggregationv1alpha1.Resource("coderworkspaces"), + "create on update", + ) } - key := workspaceKey(namespace, name) - - s.mu.Lock() - defer s.mu.Unlock() - - existing, exists := s.workspaces[key] - if !exists { - if !forceAllowCreate { - return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) - } - - createdObj, err := objInfo.UpdatedObject(ctx, &aggregationv1alpha1.CoderWorkspace{}) - if err != nil { - return nil, false, err - } - createdWorkspace, ok := createdObj.(*aggregationv1alpha1.CoderWorkspace) - if !ok { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderWorkspace, got %T", createdObj)) - } - - candidate := createdWorkspace.DeepCopy() - if candidate.Name == "" { - candidate.Name = name - } - if candidate.Name != name { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) - } - if candidate.Namespace == "" { - candidate.Namespace = namespace - } - if candidate.Namespace != namespace { - return nil, false, apierrors.NewBadRequest( - fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), - ) - } + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, false, badNamespaceErr + } - ensureWorkspaceTypeMeta(candidate) - if candidate.Generation == 0 { - candidate.Generation = 1 - } - if candidate.CreationTimestamp.IsZero() { - candidate.CreationTimestamp = metav1.Now() - } - candidate.ResourceVersion = "1" + orgName, userName, workspaceName, err := coder.ParseWorkspaceName(name) + if err != nil { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("invalid workspace name %q: %v", name, err)) + } - if createValidation != nil { - if err := createValidation(ctx, candidate); err != nil { - return nil, false, err - } - } + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, false, wrapClientError(err) + } - s.workspaces[key] = candidate.DeepCopy() - return candidate.DeepCopy(), true, nil + currentWorkspace, err := sdk.WorkspaceByOwnerAndName(ctx, userName, workspaceName, codersdk.WorkspaceOptions{}) + if err != nil { + return nil, false, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), name) + } + if currentWorkspace.OrganizationName != orgName { + return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) } - updatedObj, err := objInfo.UpdatedObject(ctx, existing.DeepCopy()) + currentK8sObj := convert.WorkspaceToK8s(namespace, currentWorkspace) + desiredObjRuntime, err := objInfo.UpdatedObject(ctx, currentK8sObj.DeepCopy()) if err != nil { return nil, false, err } - updatedWorkspace, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) - if !ok { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderWorkspace, got %T", updatedObj)) - } - candidate := updatedWorkspace.DeepCopy() - if candidate.Name == "" { - candidate.Name = name + desiredObj, ok := desiredObjRuntime.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("updated object must be *CoderWorkspace, got %T", desiredObjRuntime), + ) } - if candidate.Name != name { - return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) + if desiredObj.Name != "" && desiredObj.Name != name { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("updated object metadata.name %q must match request name %q", desiredObj.Name, name), + ) } - if candidate.Namespace == "" { - candidate.Namespace = namespace + if desiredObj.Spec.Organization != "" && desiredObj.Spec.Organization != orgName { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf( + "updated object spec.organization %q must match organization %q parsed from metadata.name", + desiredObj.Spec.Organization, + orgName, + ), + ) } - if candidate.Namespace != namespace { + if desiredObj.Namespace != "" && desiredObj.Namespace != namespace { return nil, false, apierrors.NewBadRequest( - fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), + fmt.Sprintf("metadata.namespace %q does not match request namespace %q", desiredObj.Namespace, namespace), ) } - - if candidate.ResourceVersion == "" { + if desiredObj.ResourceVersion == "" { return nil, false, apierrors.NewBadRequest("metadata.resourceVersion is required for update") } - if candidate.ResourceVersion != existing.ResourceVersion { + if desiredObj.ResourceVersion != currentK8sObj.ResourceVersion { return nil, false, apierrors.NewConflict( aggregationv1alpha1.Resource("coderworkspaces"), name, - fmt.Errorf("resourceVersion %q does not match current value %q", candidate.ResourceVersion, existing.ResourceVersion), + fmt.Errorf( + "resource version mismatch: got %q, current is %q", + desiredObj.ResourceVersion, + currentK8sObj.ResourceVersion, + ), ) } - candidate.Status = existing.Status - candidate.CreationTimestamp = existing.CreationTimestamp - candidate.Generation = existing.Generation + 1 - candidateFinalResourceVersion, err := incrementResourceVersion(existing.ResourceVersion) - if err != nil { - return nil, false, err - } - candidate.ResourceVersion = candidateFinalResourceVersion - ensureWorkspaceTypeMeta(candidate) - if updateValidation != nil { - if err := updateValidation(ctx, candidate, existing); err != nil { + if err := updateValidation(ctx, desiredObj, currentK8sObj); err != nil { return nil, false, err } } - s.workspaces[key] = candidate.DeepCopy() - return candidate.DeepCopy(), false, nil + // Workspace updates via codersdk are currently limited to workspace build + // transitions, which map only to spec.running toggles in this API. + if desiredObj.Spec.Organization != currentK8sObj.Spec.Organization || + desiredObj.Spec.TemplateName != currentK8sObj.Spec.TemplateName || + (desiredObj.Spec.TemplateVersionID != "" && desiredObj.Spec.TemplateVersionID != currentK8sObj.Spec.TemplateVersionID) || + (desiredObj.Spec.TTLMillis != nil && !equalInt64Ptr(desiredObj.Spec.TTLMillis, currentK8sObj.Spec.TTLMillis)) || + (desiredObj.Spec.AutostartSchedule != nil && !equalStringPtr(desiredObj.Spec.AutostartSchedule, currentK8sObj.Spec.AutostartSchedule)) { + return nil, false, apierrors.NewBadRequest( + "workspace update only supports changing spec.running; other spec fields are immutable", + ) + } + + if desiredObj.Spec.Running == currentK8sObj.Spec.Running { + return currentK8sObj, false, nil + } + + transition := codersdk.WorkspaceTransitionStop + if desiredObj.Spec.Running { + transition = codersdk.WorkspaceTransitionStart + } + + build, err := sdk.CreateWorkspaceBuild(ctx, currentWorkspace.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: transition, + }) + if err != nil { + return nil, false, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), name) + } + + currentWorkspace.LatestBuild = build + if !build.UpdatedAt.IsZero() { + currentWorkspace.UpdatedAt = build.UpdatedAt + } + + return convert.WorkspaceToK8s(namespace, currentWorkspace), false, nil } -// Delete removes a CoderWorkspace from the in-memory store. +// Delete requests workspace deletion through a codersdk build transition. func (s *WorkspaceStorage) Delete( ctx context.Context, name string, @@ -413,49 +445,45 @@ func (s *WorkspaceStorage) Delete( return nil, false, fmt.Errorf("assertion failed: workspace name must not be empty") } - namespace := genericapirequest.NamespaceValue(ctx) + namespace, badNamespaceErr := requiredNamespaceFromRequestContext(ctx) + if badNamespaceErr != nil { + return nil, false, badNamespaceErr + } - s.mu.Lock() - defer s.mu.Unlock() + orgName, userName, workspaceName, err := coder.ParseWorkspaceName(name) + if err != nil { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("invalid workspace name %q: %v", name, err)) + } - var ( - key string - workspace *aggregationv1alpha1.CoderWorkspace - ) - if namespace != "" { - key = workspaceKey(namespace, name) - workspace = s.workspaces[key] - } else { - matchedKeys := make([]string, 0) - for candidateKey, candidateWorkspace := range s.workspaces { - if candidateWorkspace.Name == name { - matchedKeys = append(matchedKeys, candidateKey) - } - } - if len(matchedKeys) > 1 { - return nil, false, apierrors.NewBadRequest( - fmt.Sprintf("workspace name %q is ambiguous across namespaces; specify namespace", name), - ) - } - if len(matchedKeys) == 1 { - key = matchedKeys[0] - workspace = s.workspaces[key] - } + sdk, err := s.clientForNamespace(ctx, namespace) + if err != nil { + return nil, false, wrapClientError(err) } - if workspace == nil { + workspace, err := sdk.WorkspaceByOwnerAndName(ctx, userName, workspaceName, codersdk.WorkspaceOptions{}) + if err != nil { + return nil, false, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), name) + } + if workspace.OrganizationName != orgName { return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) } if deleteValidation != nil { - if err := deleteValidation(ctx, workspace.DeepCopy()); err != nil { - return nil, false, err + if validationErr := deleteValidation(ctx, convert.WorkspaceToK8s(namespace, workspace)); validationErr != nil { + return nil, false, validationErr } } - deleted := workspace.DeepCopy() - delete(s.workspaces, key) - return deleted, true, nil + _, err = sdk.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + Transition: codersdk.WorkspaceTransitionDelete, + }) + if err != nil { + return nil, false, coder.MapCoderError(err, aggregationv1alpha1.Resource("coderworkspaces"), name) + } + + // Deletion is asynchronous in Coder: we only enqueue a delete build transition here. + // Report deleted=false so Kubernetes callers know the resource is not gone yet. + return &metav1.Status{Status: metav1.StatusSuccess}, false, nil } // ConvertToTable converts a workspace object or list into kubectl table output. @@ -470,32 +498,78 @@ func (s *WorkspaceStorage) ConvertToTable(ctx context.Context, object, tableOpti return s.tableConvertor.ConvertToTable(ctx, object, tableOptions) } -func ensureWorkspaceTypeMeta(workspace *aggregationv1alpha1.CoderWorkspace) { - if workspace == nil { - panic("assertion failed: workspace must not be nil") +func (s *WorkspaceStorage) clientForNamespace(ctx context.Context, namespace string) (*codersdk.Client, error) { + if s.provider == nil { + return nil, fmt.Errorf("assertion failed: workspace client provider must not be nil") } - workspace.TypeMeta = metav1.TypeMeta{ - Kind: "CoderWorkspace", - APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), + + sdk, err := s.provider.ClientForNamespace(ctx, namespace) + if err != nil { + return nil, fmt.Errorf("resolve codersdk client for namespace %q: %w", namespace, err) + } + if sdk == nil { + return nil, fmt.Errorf("assertion failed: workspace client provider returned nil codersdk client") } + + return sdk, nil } -func (s *WorkspaceStorage) findWorkspaceByNameLocked(name string) (*aggregationv1alpha1.CoderWorkspace, bool, bool) { - matchedKeys := make([]string, 0) - for key, workspace := range s.workspaces { - if workspace.Name == name { - matchedKeys = append(matchedKeys, key) - } +func namespaceFromRequestContext(ctx context.Context) (string, error) { + if ctx == nil { + return "", fmt.Errorf("assertion failed: context must not be nil") } - if len(matchedKeys) == 0 { - return nil, false, false + + return genericapirequest.NamespaceValue(ctx), nil +} + +func requiredNamespaceFromRequestContext(ctx context.Context) (string, error) { + namespace, err := namespaceFromRequestContext(ctx) + if err != nil { + return "", err } - if len(matchedKeys) > 1 { - return nil, false, true + if namespace == "" { + return "", apierrors.NewBadRequest("namespace is required") + } + + return namespace, nil +} + +func namespaceForListConversion(requestNamespace string, provider coder.ClientProvider) (string, error) { + if requestNamespace != "" { + return requestNamespace, nil } - workspace := s.workspaces[matchedKeys[0]] - if workspace == nil { - return nil, false, false + if provider == nil { + return "", fmt.Errorf("assertion failed: client provider must not be nil") } - return workspace, true, false + + staticProvider, ok := provider.(*coder.StaticClientProvider) + if !ok || staticProvider.Namespace == "" { + return "", apierrors.NewServiceUnavailable( + "all-namespaces list requires a namespace-pinned static provider; configure --coder-namespace", + ) + } + + return staticProvider.Namespace, nil +} + +func equalInt64Ptr(a, b *int64) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + return *a == *b +} + +func equalStringPtr(a, b *string) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + return *a == *b } diff --git a/internal/aggregated/storage/workspace_test.go b/internal/aggregated/storage/workspace_test.go deleted file mode 100644 index b77cc98f..00000000 --- a/internal/aggregated/storage/workspace_test.go +++ /dev/null @@ -1,231 +0,0 @@ -package storage - -import ( - "context" - "testing" - "time" - - apierrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - genericapirequest "k8s.io/apiserver/pkg/endpoints/request" - "k8s.io/apiserver/pkg/registry/rest" - - aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" -) - -func TestWorkspaceStorageCRUDLifecycle(t *testing.T) { - t.Helper() - - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - createdObj, err := workspaceStorage.Create(ctx, &aggregationv1alpha1.CoderWorkspace{ - ObjectMeta: metav1.ObjectMeta{Name: "unit-workspace"}, - Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, - }, rest.ValidateAllObjectFunc, nil) - if err != nil { - t.Fatalf("create workspace: %v", err) - } - - created, ok := createdObj.(*aggregationv1alpha1.CoderWorkspace) - if !ok { - t.Fatalf("expected *CoderWorkspace from create, got %T", createdObj) - } - if created.Namespace != "default" { - t.Fatalf("expected namespace default, got %q", created.Namespace) - } - if created.ResourceVersion != "1" { - t.Fatalf("expected resourceVersion 1, got %q", created.ResourceVersion) - } - if created.Generation != 1 { - t.Fatalf("expected generation 1, got %d", created.Generation) - } - - toUpdate := created.DeepCopy() - toUpdate.Spec.Running = false - toUpdate.ResourceVersion = created.ResourceVersion - updatedObj, createdOnUpdate, err := workspaceStorage.Update( - ctx, - toUpdate.Name, - rest.DefaultUpdatedObjectInfo(toUpdate), - nil, - nil, - false, - nil, - ) - if err != nil { - t.Fatalf("update workspace: %v", err) - } - if createdOnUpdate { - t.Fatal("expected update of existing workspace, got createdOnUpdate=true") - } - - updated, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) - if !ok { - t.Fatalf("expected *CoderWorkspace from update, got %T", updatedObj) - } - if updated.Spec.Running { - t.Fatalf("expected running=false after update, got %+v", updated.Spec) - } - if updated.ResourceVersion == created.ResourceVersion { - t.Fatalf("expected resourceVersion to change, got %q", updated.ResourceVersion) - } - if updated.Generation != created.Generation+1 { - t.Fatalf("expected generation increment to %d, got %d", created.Generation+1, updated.Generation) - } - - deletedObj, deletedNow, err := workspaceStorage.Delete(ctx, created.Name, nil, nil) - if err != nil { - t.Fatalf("delete workspace: %v", err) - } - if !deletedNow { - t.Fatal("expected immediate delete") - } - if _, ok := deletedObj.(*aggregationv1alpha1.CoderWorkspace); !ok { - t.Fatalf("expected *CoderWorkspace from delete, got %T", deletedObj) - } - - _, err = workspaceStorage.Get(ctx, created.Name, nil) - if !apierrors.IsNotFound(err) { - t.Fatalf("expected NotFound after delete, got %v", err) - } -} - -func TestWorkspaceStorageCreateAlreadyExists(t *testing.T) { - t.Helper() - - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - _, err := workspaceStorage.Create(ctx, &aggregationv1alpha1.CoderWorkspace{ - ObjectMeta: metav1.ObjectMeta{Name: "dev-workspace"}, - Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, - }, nil, nil) - if !apierrors.IsAlreadyExists(err) { - t.Fatalf("expected AlreadyExists error, got %v", err) - } -} - -func TestWorkspaceStorageUpdateRejectsNamespaceChange(t *testing.T) { - t.Helper() - - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - currentObj, err := workspaceStorage.Get(ctx, "dev-workspace", nil) - if err != nil { - t.Fatalf("get workspace: %v", err) - } - current := currentObj.(*aggregationv1alpha1.CoderWorkspace) - - modified := current.DeepCopy() - modified.Namespace = "sandbox" - modified.ResourceVersion = current.ResourceVersion - - _, _, err = workspaceStorage.Update( - ctx, - modified.Name, - rest.DefaultUpdatedObjectInfo(modified), - nil, - nil, - false, - nil, - ) - if !apierrors.IsBadRequest(err) { - t.Fatalf("expected BadRequest for namespace mismatch, got %v", err) - } -} - -func TestWorkspaceStorageUpdateIgnoresStatusWrites(t *testing.T) { - t.Helper() - - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - currentObj, err := workspaceStorage.Get(ctx, "dev-workspace", nil) - if err != nil { - t.Fatalf("get workspace: %v", err) - } - current := currentObj.(*aggregationv1alpha1.CoderWorkspace) - if current.Status.AutoShutdown == nil { - t.Fatal("expected seeded workspace status autoShutdown") - } - - modified := current.DeepCopy() - modified.Spec.Running = !current.Spec.Running - modified.ResourceVersion = current.ResourceVersion - overrideDeadline := metav1.NewTime(time.Date(2040, time.January, 1, 0, 0, 0, 0, time.UTC)) - modified.Status.AutoShutdown = &overrideDeadline - - updatedObj, _, err := workspaceStorage.Update( - ctx, - modified.Name, - rest.DefaultUpdatedObjectInfo(modified), - nil, - nil, - false, - nil, - ) - if err != nil { - t.Fatalf("update workspace: %v", err) - } - - updated := updatedObj.(*aggregationv1alpha1.CoderWorkspace) - if updated.Status.AutoShutdown == nil { - t.Fatal("expected status autoShutdown to remain present") - } - if !updated.Status.AutoShutdown.Equal(current.Status.AutoShutdown) { - t.Fatalf("expected status to remain unchanged, got %s want %s", updated.Status.AutoShutdown, current.Status.AutoShutdown) - } -} - -func TestWorkspaceStorageDeleteAmbiguousWithoutNamespace(t *testing.T) { - t.Helper() - - workspaceStorage := NewWorkspaceStorage() - - _, err := workspaceStorage.Create( - genericapirequest.WithNamespace(context.Background(), "sandbox"), - &aggregationv1alpha1.CoderWorkspace{ObjectMeta: metav1.ObjectMeta{Name: "dev-workspace"}}, - nil, - nil, - ) - if err != nil { - t.Fatalf("seed same-name workspace in sandbox namespace: %v", err) - } - - _, _, err = workspaceStorage.Delete(context.Background(), "dev-workspace", nil, nil) - if !apierrors.IsBadRequest(err) { - t.Fatalf("expected BadRequest for ambiguous delete, got %v", err) - } -} - -func TestWorkspaceStorageUpdateRequiresResourceVersion(t *testing.T) { - t.Helper() - - workspaceStorage := NewWorkspaceStorage() - ctx := genericapirequest.WithNamespace(context.Background(), "default") - - currentObj, err := workspaceStorage.Get(ctx, "dev-workspace", nil) - if err != nil { - t.Fatalf("get workspace: %v", err) - } - current := currentObj.(*aggregationv1alpha1.CoderWorkspace) - - modified := current.DeepCopy() - modified.Spec.Running = !current.Spec.Running - modified.ResourceVersion = "" - - _, _, err = workspaceStorage.Update( - ctx, - modified.Name, - rest.DefaultUpdatedObjectInfo(modified), - nil, - nil, - false, - nil, - ) - if !apierrors.IsBadRequest(err) { - t.Fatalf("expected BadRequest when resourceVersion is missing, got %v", err) - } -} diff --git a/internal/app/apiserverapp/apiserverapp.go b/internal/app/apiserverapp/apiserverapp.go index 18182446..bd8ec0a2 100644 --- a/internal/app/apiserverapp/apiserverapp.go +++ b/internal/app/apiserverapp/apiserverapp.go @@ -4,7 +4,14 @@ package apiserverapp import ( "context" "fmt" - + "log" + "net" + "net/url" + "strings" + "time" + + "github.com/coder/coder/v2/codersdk" + apierrors "k8s.io/apimachinery/pkg/api/errors" metainternalversion "k8s.io/apimachinery/pkg/apis/meta/internalversion" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -23,6 +30,7 @@ import ( "k8s.io/kube-openapi/pkg/validation/spec" aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder-k8s/internal/aggregated/coder" "github.com/coder/coder-k8s/internal/aggregated/storage" ) @@ -32,6 +40,100 @@ const ( serverName = "coder-k8s-aggregated-apiserver" ) +// Options configures aggregated-apiserver bootstrap behavior. +type Options struct { + // SecureServingPort used when Listener is nil. Default: DefaultSecureServingPort. + SecureServingPort int + // Listener allows tests to bind to 127.0.0.1:0. + Listener net.Listener + // CoderURL is an optional fallback URL when CoderControlPlane status has no URL. + CoderURL string + // CoderSessionToken is the admin session token. + CoderSessionToken string + // CoderNamespace restricts the provider to serve only this namespace. + // When non-empty, requests to other namespaces are rejected. + CoderNamespace string + // CoderRequestTimeout for SDK calls. Default 30s. + CoderRequestTimeout time.Duration +} + +type errClientProvider struct { + serviceUnavailableMessage string +} + +var _ coder.ClientProvider = (*errClientProvider)(nil) + +func (p *errClientProvider) ClientForNamespace(ctx context.Context, _ string) (*codersdk.Client, error) { + if p == nil { + return nil, fmt.Errorf("assertion failed: err client provider must not be nil") + } + if ctx == nil { + return nil, fmt.Errorf("assertion failed: context must not be nil") + } + if p.serviceUnavailableMessage == "" { + return nil, fmt.Errorf("assertion failed: service unavailable message must not be empty") + } + + return nil, apierrors.NewServiceUnavailable(p.serviceUnavailableMessage) +} + +func buildClientProvider(opts Options, requestTimeout time.Duration) (coder.ClientProvider, error) { + if requestTimeout <= 0 { + return nil, fmt.Errorf("assertion failed: request timeout must be positive") + } + + coderURL := strings.TrimSpace(opts.CoderURL) + sessionToken := strings.TrimSpace(opts.CoderSessionToken) + missing := make([]string, 0, 2) + if coderURL == "" { + missing = append(missing, "coder URL") + } + if sessionToken == "" { + missing = append(missing, "coder session token") + } + if len(missing) > 0 { + message := fmt.Sprintf( + "coder client provider is not configured: missing %s; configure --coder-url and --coder-session-token", + strings.Join(missing, " and "), + ) + if len(missing) == 2 { + return &errClientProvider{serviceUnavailableMessage: message}, nil + } + + return nil, fmt.Errorf("coder client provider is partially configured: %s", message) + } + + coderNamespace := strings.TrimSpace(opts.CoderNamespace) + if coderNamespace == "" { + return nil, fmt.Errorf("coder client provider namespace is not configured: configure --coder-namespace") + } + + parsedCoderURL, err := url.Parse(coderURL) + if err != nil { + return nil, fmt.Errorf("parse coder URL %q: %w", coderURL, err) + } + if parsedCoderURL == nil { + return nil, fmt.Errorf("assertion failed: parsed coder URL must not be nil") + } + + provider, err := coder.NewStaticClientProvider( + coder.Config{ + CoderURL: parsedCoderURL, + SessionToken: sessionToken, + RequestTimeout: requestTimeout, + }, + coderNamespace, + ) + if err != nil { + return nil, err + } + if provider == nil { + return nil, fmt.Errorf("assertion failed: coder client provider is nil after successful construction") + } + + return provider, nil +} + // NewScheme builds the runtime scheme used by the aggregated API server. func NewScheme() *runtime.Scheme { scheme := runtime.NewScheme() @@ -96,10 +198,17 @@ func NewRecommendedConfig( } // NewAPIGroupInfo creates APIGroupInfo for the aggregation.coder.com API group. -func NewAPIGroupInfo(scheme *runtime.Scheme, codecs serializer.CodecFactory) (*genericapiserver.APIGroupInfo, error) { +func NewAPIGroupInfo( + scheme *runtime.Scheme, + codecs serializer.CodecFactory, + provider coder.ClientProvider, +) (*genericapiserver.APIGroupInfo, error) { if scheme == nil { return nil, fmt.Errorf("assertion failed: scheme must not be nil") } + if provider == nil { + return nil, fmt.Errorf("assertion failed: coder client provider must not be nil") + } parameterCodec := runtime.NewParameterCodec(scheme) apiGroupInfo := genericapiserver.NewDefaultAPIGroupInfo( @@ -109,8 +218,8 @@ func NewAPIGroupInfo(scheme *runtime.Scheme, codecs serializer.CodecFactory) (*g codecs, ) apiGroupInfo.VersionedResourcesStorageMap[aggregationv1alpha1.SchemeGroupVersion.Version] = map[string]rest.Storage{ - "coderworkspaces": storage.NewWorkspaceStorage(), - "codertemplates": storage.NewTemplateStorage(), + "coderworkspaces": storage.NewWorkspaceStorage(provider), + "codertemplates": storage.NewTemplateStorage(provider), } return &apiGroupInfo, nil } @@ -146,9 +255,34 @@ func NewGenericAPIServer(recommendedConfig *genericapiserver.RecommendedConfig) // Run starts the aggregated API server application mode. func Run(ctx context.Context) error { + return RunWithOptions(ctx, Options{}) +} + +// RunWithOptions starts the aggregated API server application mode. +func RunWithOptions(ctx context.Context, opts Options) error { if ctx == nil { return fmt.Errorf("assertion failed: context must not be nil") } + if opts.CoderRequestTimeout < 0 { + return fmt.Errorf("assertion failed: coder request timeout must not be negative") + } + + requestTimeout := opts.CoderRequestTimeout + if requestTimeout == 0 { + requestTimeout = 30 * time.Second + } + + provider, err := buildClientProvider(opts, requestTimeout) + if err != nil { + return fmt.Errorf("build coder client provider: %w", err) + } + if provider == nil { + return fmt.Errorf("assertion failed: coder client provider is nil after successful construction") + } + + if errProvider, ok := provider.(*errClientProvider); ok { + log.Printf("warning: %s", errProvider.serviceUnavailableMessage) + } scheme := NewScheme() if scheme == nil { @@ -157,7 +291,18 @@ func Run(ctx context.Context) error { codecs := serializer.NewCodecFactory(scheme) secureServingOptions := genericoptions.NewSecureServingOptions() - secureServingOptions.BindPort = DefaultSecureServingPort + secureServingPort := opts.SecureServingPort + if secureServingPort == 0 { + secureServingPort = DefaultSecureServingPort + } + if secureServingPort < 0 { + return fmt.Errorf("assertion failed: secure serving port must not be negative") + } + secureServingOptions.BindPort = secureServingPort + if opts.Listener != nil { + secureServingOptions.Listener = opts.Listener + secureServingOptions.BindPort = 0 + } secureServingOptions.ServerCert.CertDirectory = "" secureServingOptions.ServerCert.PairName = "" @@ -171,7 +316,7 @@ func Run(ctx context.Context) error { return err } - apiGroupInfo, err := NewAPIGroupInfo(scheme, codecs) + apiGroupInfo, err := NewAPIGroupInfo(scheme, codecs, provider) if err != nil { return fmt.Errorf("build API group info: %w", err) } @@ -190,6 +335,8 @@ func getOpenAPIDefinitions(_ openapicommon.ReferenceCallback) map[string]openapi boolSchema := spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{"boolean"}}} dateTimeSchema := spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{"string"}, Format: "date-time"}} + int64Schema := spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{"integer"}, Format: "int64"}} + stringSchema := spec.Schema{SchemaProps: spec.SchemaProps{Type: []string{"string"}}} workspaceSchema := spec.Schema{ SchemaProps: spec.SchemaProps{ @@ -199,7 +346,12 @@ func getOpenAPIDefinitions(_ openapicommon.ReferenceCallback) map[string]openapi SchemaProps: spec.SchemaProps{ Type: []string{"object"}, Properties: map[string]spec.Schema{ - "running": boolSchema, + "organization": stringSchema, + "templateName": stringSchema, + "templateVersionID": stringSchema, + "running": boolSchema, + "ttlMillis": int64Schema, + "autostartSchedule": stringSchema, }, }, }, @@ -207,7 +359,14 @@ func getOpenAPIDefinitions(_ openapicommon.ReferenceCallback) map[string]openapi SchemaProps: spec.SchemaProps{ Type: []string{"object"}, Properties: map[string]spec.Schema{ - "autoShutdown": dateTimeSchema, + "id": stringSchema, + "ownerName": stringSchema, + "organizationName": stringSchema, + "templateName": stringSchema, + "latestBuildID": stringSchema, + "latestBuildStatus": stringSchema, + "autoShutdown": dateTimeSchema, + "lastUsedAt": dateTimeSchema, }, }, }, @@ -223,7 +382,12 @@ func getOpenAPIDefinitions(_ openapicommon.ReferenceCallback) map[string]openapi SchemaProps: spec.SchemaProps{ Type: []string{"object"}, Properties: map[string]spec.Schema{ - "running": boolSchema, + "organization": stringSchema, + "versionID": stringSchema, + "displayName": stringSchema, + "description": stringSchema, + "icon": stringSchema, + "running": boolSchema, }, }, }, @@ -231,7 +395,12 @@ func getOpenAPIDefinitions(_ openapicommon.ReferenceCallback) map[string]openapi SchemaProps: spec.SchemaProps{ Type: []string{"object"}, Properties: map[string]spec.Schema{ - "autoShutdown": dateTimeSchema, + "id": stringSchema, + "organizationName": stringSchema, + "activeVersionID": stringSchema, + "deprecated": boolSchema, + "updatedAt": dateTimeSchema, + "autoShutdown": dateTimeSchema, }, }, }, diff --git a/internal/app/apiserverapp/apiserverapp_test.go b/internal/app/apiserverapp/apiserverapp_test.go index cf0fb8c9..7e3dee5a 100644 --- a/internal/app/apiserverapp/apiserverapp_test.go +++ b/internal/app/apiserverapp/apiserverapp_test.go @@ -2,15 +2,21 @@ package apiserverapp import ( "context" + "errors" "net" "net/http/httptest" + "net/url" + "strings" "testing" + "time" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/serializer" genericoptions "k8s.io/apiserver/pkg/server/options" aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + coderhelper "github.com/coder/coder-k8s/internal/aggregated/coder" ) func TestNewSchemeRegistersAggregationKinds(t *testing.T) { @@ -67,7 +73,22 @@ func TestInstallAPIGroupRegistersDiscovery(t *testing.T) { } defer server.Destroy() - apiGroupInfo, err := NewAPIGroupInfo(scheme, codecs) + coderURL, err := url.Parse("http://localhost:8080") + if err != nil { + t.Fatalf("parse test coder URL: %v", err) + } + provider, err := coderhelper.NewStaticClientProvider( + coderhelper.Config{ + CoderURL: coderURL, + SessionToken: "test-session-token", + }, + "", + ) + if err != nil { + t.Fatalf("build static client provider: %v", err) + } + + apiGroupInfo, err := NewAPIGroupInfo(scheme, codecs, provider) if err != nil { t.Fatalf("build API group info: %v", err) } @@ -106,3 +127,175 @@ func TestInstallAPIGroupRegistersDiscovery(t *testing.T) { t.Fatalf("expected discovery registration for group %s", aggregationv1alpha1.SchemeGroupVersion.Group) } } + +func TestBuildClientProviderDefersMissingCoderConfigAsServiceUnavailable(t *testing.T) { + t.Parallel() + + provider, err := buildClientProvider(Options{}, 30*time.Second) + if err != nil { + t.Fatalf("expected missing coder config to return a deferred-error provider, got %v", err) + } + if provider == nil { + t.Fatal("expected non-nil provider when coder config is missing") + } + + sdkClient, err := provider.ClientForNamespace(context.Background(), "control-plane") + if sdkClient != nil { + t.Fatalf("expected nil sdk client when coder config is missing, got %T", sdkClient) + } + if !apierrors.IsServiceUnavailable(err) { + t.Fatalf("expected ServiceUnavailable when provider is not configured, got %v", err) + } + if err == nil || !strings.Contains(err.Error(), "configure --coder-url and --coder-session-token") { + t.Fatalf("expected missing-config error message, got %v", err) + } +} + +func TestBuildClientProviderRejectsPartialCoderConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts Options + }{ + { + name: "missing coder URL", + opts: Options{CoderSessionToken: "test-session-token"}, + }, + { + name: "missing coder session token", + opts: Options{CoderURL: "https://coder.example.com"}, + }, + } + + for _, testCase := range tests { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + _, err := buildClientProvider(testCase.opts, 30*time.Second) + if err == nil { + t.Fatal("expected partial coder config to fail") + } + if !strings.Contains(err.Error(), "partially configured") { + t.Fatalf("expected partial-config error, got %v", err) + } + }) + } +} + +func TestRunWithOptionsRejectsPartialCoderConfig(t *testing.T) { + t.Parallel() + + err := RunWithOptions(context.Background(), Options{CoderURL: "https://coder.example.com"}) + if err == nil { + t.Fatal("expected partial coder config to fail startup") + } + if !strings.Contains(err.Error(), "partially configured") { + t.Fatalf("expected partial-config startup error, got %v", err) + } +} + +func TestBuildClientProviderRejectsMissingCoderNamespaceWhenBackendConfigured(t *testing.T) { + t.Parallel() + + _, err := buildClientProvider(Options{ + CoderURL: "https://coder.example.com", + CoderSessionToken: "test-session-token", + }, 30*time.Second) + if err == nil { + t.Fatal("expected missing coder namespace to fail when backend is otherwise configured") + } + if !strings.Contains(err.Error(), "configure --coder-namespace") { + t.Fatalf("expected missing namespace error to mention --coder-namespace, got %v", err) + } +} + +func TestRunWithOptionsRejectsMissingCoderNamespaceWhenBackendConfigured(t *testing.T) { + t.Parallel() + + err := RunWithOptions(context.Background(), Options{ + CoderURL: "https://coder.example.com", + CoderSessionToken: "test-session-token", + }) + if err == nil { + t.Fatal("expected missing coder namespace to fail startup when backend is otherwise configured") + } + if !strings.Contains(err.Error(), "configure --coder-namespace") { + t.Fatalf("expected missing namespace startup error to mention --coder-namespace, got %v", err) + } +} + +func TestRunWithOptionsStartsWithMissingCoderConfig(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("create test listener: %v", err) + } + defer func() { + _ = listener.Close() + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- RunWithOptions(ctx, Options{Listener: listener}) + }() + + select { + case runErr := <-errCh: + t.Fatalf("expected startup to continue with deferred coder config, got %v", runErr) + case <-time.After(300 * time.Millisecond): + } + + cancel() + + select { + case runErr := <-errCh: + if runErr != nil && !errors.Is(runErr, context.Canceled) { + t.Fatalf("expected graceful shutdown after cancellation, got %v", runErr) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for aggregated apiserver shutdown") + } +} + +func TestBuildClientProviderReturnsStaticProviderWithCoderConfig(t *testing.T) { + t.Parallel() + + provider, err := buildClientProvider(Options{ + CoderURL: "https://coder.example.com", + CoderSessionToken: "test-session-token", + CoderNamespace: "control-plane", + }, 30*time.Second) + if err != nil { + t.Fatalf("build client provider: %v", err) + } + + staticProvider, ok := provider.(*coderhelper.StaticClientProvider) + if !ok { + t.Fatalf("expected *coder.StaticClientProvider, got %T", provider) + } + if got, want := staticProvider.Namespace, "control-plane"; got != want { + t.Fatalf("expected provider namespace %q, got %q", want, got) + } + + sdkClient, err := staticProvider.ClientForNamespace(context.Background(), "control-plane") + if err != nil { + t.Fatalf("resolve static client for namespace: %v", err) + } + if sdkClient == nil { + t.Fatal("expected non-nil sdk client") + } + if got, want := sdkClient.URL.String(), "https://coder.example.com"; got != want { + t.Fatalf("expected client URL %q, got %q", want, got) + } + + _, err = staticProvider.ClientForNamespace(context.Background(), "default") + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for namespace outside provider scope, got %v", err) + } +} diff --git a/internal/app/apiserverapp/integration_test.go b/internal/app/apiserverapp/integration_test.go new file mode 100644 index 00000000..e9ec819a --- /dev/null +++ b/internal/app/apiserverapp/integration_test.go @@ -0,0 +1,390 @@ +package apiserverapp + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "k8s.io/apimachinery/pkg/runtime/serializer" + genericoptions "k8s.io/apiserver/pkg/server/options" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + "github.com/coder/coder-k8s/internal/aggregated/coder" + "github.com/coder/coder/v2/codersdk" +) + +func TestIntegrationAggregatedAPIServerBootstrapAndList(t *testing.T) { + t.Parallel() + + mockCoder := newIntegrationMockCoderServer("test-token") + defer mockCoder.Close() + + mockCoderURLString := mockCoder.URL() + mockCoderURL, err := url.Parse(mockCoderURLString) + if err != nil { + t.Fatalf("parse mock coder URL %q: %v", mockCoderURLString, err) + } + + sdkClient := codersdk.New(mockCoderURL) + if sdkClient == nil { + t.Fatal("assertion failed: codersdk client must not be nil") + } + sdkClient.SetSessionToken("test-token") + + provider := &coder.StaticClientProvider{Client: sdkClient, Namespace: "test-ns"} + if provider.Client == nil { + t.Fatal("assertion failed: provider client must not be nil") + } + + scheme := NewScheme() + if scheme == nil { + t.Fatal("assertion failed: scheme must not be nil") + } + codecs := serializer.NewCodecFactory(scheme) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("create aggregated API listener: %v", err) + } + defer func() { + _ = listener.Close() + }() + + secureServingOptions := genericoptions.NewSecureServingOptions() + if secureServingOptions == nil { + t.Fatal("assertion failed: secure serving options must not be nil") + } + secureServingOptions.Listener = listener + secureServingOptions.BindPort = 0 + secureServingOptions.ServerCert.CertDirectory = "" + secureServingOptions.ServerCert.PairName = "" + + recommendedConfig, err := NewRecommendedConfig(scheme, codecs, secureServingOptions) + if err != nil { + t.Fatalf("build recommended config: %v", err) + } + if recommendedConfig == nil { + t.Fatal("assertion failed: recommended config must not be nil") + } + if recommendedConfig.LoopbackClientConfig == nil { + t.Fatal("assertion failed: loopback client config must not be nil") + } + if recommendedConfig.LoopbackClientConfig.Host == "" { + t.Fatal("assertion failed: loopback client host must not be empty") + } + + server, err := NewGenericAPIServer(recommendedConfig) + if err != nil { + t.Fatalf("build generic API server: %v", err) + } + if server == nil { + t.Fatal("assertion failed: generic API server must not be nil") + } + defer server.Destroy() + + apiGroupInfo, err := NewAPIGroupInfo(scheme, codecs, provider) + if err != nil { + t.Fatalf("build API group info: %v", err) + } + if apiGroupInfo == nil { + t.Fatal("assertion failed: API group info must not be nil") + } + if err := InstallAPIGroup(server, apiGroupInfo); err != nil { + t.Fatalf("install API group: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + go func() { + errCh <- server.PrepareRun().RunWithContext(ctx) + }() + defer func() { + cancel() + select { + case runErr := <-errCh: + if runErr != nil && !errors.Is(runErr, context.Canceled) { + t.Errorf("aggregated API server exited with error: %v", runErr) + } + case <-time.After(5 * time.Second): + t.Error("timed out waiting for aggregated API server to stop") + } + }() + + httpClient := &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + //nolint:gosec // Integration test uses ephemeral self-signed certs. + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + } + + baseURL := strings.TrimSuffix(recommendedConfig.LoopbackClientConfig.Host, "/") + if baseURL == "" { + t.Fatal("assertion failed: base URL must not be empty") + } + + templateListURL := fmt.Sprintf( + "%s/apis/aggregation.coder.com/v1alpha1/namespaces/test-ns/codertemplates", + baseURL, + ) + workspaceListURL := fmt.Sprintf( + "%s/apis/aggregation.coder.com/v1alpha1/namespaces/test-ns/coderworkspaces", + baseURL, + ) + + var templateList aggregationv1alpha1.CoderTemplateList + mustGetJSONWithRetry(t, httpClient, errCh, templateListURL, &templateList) + if len(templateList.Items) != 1 { + t.Fatalf("expected 1 template, got %d", len(templateList.Items)) + } + if got := templateList.Items[0].Name; got != "default.my-template" { + t.Fatalf("expected template name default.my-template, got %q", got) + } + if got := templateList.Items[0].Namespace; got != "test-ns" { + t.Fatalf("expected template namespace test-ns, got %q", got) + } + + var workspaceList aggregationv1alpha1.CoderWorkspaceList + mustGetJSONWithRetry(t, httpClient, errCh, workspaceListURL, &workspaceList) + if len(workspaceList.Items) != 1 { + t.Fatalf("expected 1 workspace, got %d", len(workspaceList.Items)) + } + if got := workspaceList.Items[0].Name; got != "default.testuser.my-workspace" { + t.Fatalf("expected workspace name default.testuser.my-workspace, got %q", got) + } + if got := workspaceList.Items[0].Namespace; got != "test-ns" { + t.Fatalf("expected workspace namespace test-ns, got %q", got) + } +} + +func mustGetJSONWithRetry(t *testing.T, client *http.Client, errCh <-chan error, requestURL string, target any) { + t.Helper() + + if client == nil { + t.Fatal("assertion failed: HTTP client must not be nil") + } + if requestURL == "" { + t.Fatal("assertion failed: request URL must not be empty") + } + if target == nil { + t.Fatal("assertion failed: decode target must not be nil") + } + + deadline := time.Now().Add(10 * time.Second) + var lastErr error + + for time.Now().Before(deadline) { + select { + case runErr := <-errCh: + t.Fatalf("aggregated API server exited before request %q completed: %v", requestURL, runErr) + default: + } + + request, err := http.NewRequest(http.MethodGet, requestURL, nil) + if err != nil { + t.Fatalf("create request %q: %v", requestURL, err) + } + request.Header.Set("Accept", "application/json") + + response, err := client.Do(request) + if err != nil { + lastErr = err + time.Sleep(50 * time.Millisecond) + continue + } + + body, err := io.ReadAll(response.Body) + closeErr := response.Body.Close() + if err != nil { + t.Fatalf("read response body for %q: %v", requestURL, err) + } + if closeErr != nil { + t.Fatalf("close response body for %q: %v", requestURL, closeErr) + } + + if response.StatusCode == http.StatusOK { + if err := json.Unmarshal(body, target); err != nil { + t.Fatalf("decode response for %q: %v (body=%q)", requestURL, err, string(body)) + } + return + } + + lastErr = fmt.Errorf("unexpected status for %q: %d body=%s", requestURL, response.StatusCode, string(body)) + time.Sleep(50 * time.Millisecond) + } + + t.Fatalf("request %q did not succeed before timeout: %v", requestURL, lastErr) +} + +type integrationMockCoderServer struct { + server *httptest.Server +} + +func newIntegrationMockCoderServer(expectedSessionToken string) *integrationMockCoderServer { + if expectedSessionToken == "" { + panic("assertion failed: expected session token must not be empty") + } + + organizationID := uuid.MustParse("11111111-1111-1111-1111-111111111111") + templateID := uuid.MustParse("22222222-2222-2222-2222-222222222222") + templateVersionID := uuid.MustParse("33333333-3333-3333-3333-333333333333") + workspaceID := uuid.MustParse("44444444-4444-4444-4444-444444444444") + workspaceBuildID := uuid.MustParse("55555555-5555-5555-5555-555555555555") + now := time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) + + organization := codersdk.Organization{ + MinimalOrganization: codersdk.MinimalOrganization{ + ID: organizationID, + Name: "default", + }, + CreatedAt: now, + UpdatedAt: now, + } + + template := codersdk.Template{ + ID: templateID, + Name: "my-template", + OrganizationName: "default", + OrganizationID: organizationID, + ActiveVersionID: templateVersionID, + DisplayName: "My Template", + CreatedAt: now, + UpdatedAt: now, + } + + workspace := codersdk.Workspace{ + ID: workspaceID, + Name: "my-workspace", + OwnerName: "testuser", + OrganizationName: "default", + OrganizationID: organizationID, + TemplateName: "my-template", + TemplateID: templateID, + CreatedAt: now, + UpdatedAt: now, + LastUsedAt: now, + LatestBuild: codersdk.WorkspaceBuild{ + ID: workspaceBuildID, + Transition: codersdk.WorkspaceTransitionStart, + Status: codersdk.WorkspaceStatusRunning, + TemplateVersionID: templateVersionID, + CreatedAt: now, + UpdatedAt: now, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if token := r.Header.Get(codersdk.SessionTokenHeader); token != expectedSessionToken { + writeCoderError(w, http.StatusUnauthorized, fmt.Sprintf("unexpected session token %q", token)) + return + } + + segments := splitPath(r.URL.Path) + switch { + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "organizations") && len(segments) == 4: + orgSegment := segments[3] + if orgSegment != organization.Name && orgSegment != organization.ID.String() { + writeCoderError(w, http.StatusNotFound, "organization not found") + return + } + writeJSON(w, http.StatusOK, organization) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "templates") && len(segments) == 3: + writeJSON(w, http.StatusOK, []codersdk.Template{template}) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "organizations") && len(segments) == 6 && segments[4] == "templates": + orgSegment := segments[3] + templateSegment := segments[5] + if orgSegment != organization.Name && orgSegment != organization.ID.String() { + writeCoderError(w, http.StatusNotFound, "organization not found") + return + } + if templateSegment != template.Name { + writeCoderError(w, http.StatusNotFound, "template not found") + return + } + writeJSON(w, http.StatusOK, template) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "workspaces") && len(segments) == 3: + writeJSON(w, http.StatusOK, codersdk.WorkspacesResponse{Workspaces: []codersdk.Workspace{workspace}, Count: 1}) + return + case r.Method == http.MethodGet && hasSegments(segments, "api", "v2", "users") && len(segments) == 6 && segments[4] == "workspace": + ownerSegment := segments[3] + workspaceSegment := segments[5] + if ownerSegment != workspace.OwnerName || workspaceSegment != workspace.Name { + writeCoderError(w, http.StatusNotFound, "workspace not found") + return + } + writeJSON(w, http.StatusOK, workspace) + return + default: + writeCoderError(w, http.StatusNotFound, fmt.Sprintf("unexpected route: %s %s", r.Method, r.URL.Path)) + return + } + })) + + return &integrationMockCoderServer{server: server} +} + +func (s *integrationMockCoderServer) URL() string { + if s == nil { + panic("assertion failed: integration mock coder server must not be nil") + } + if s.server == nil { + panic("assertion failed: integration mock coder server backing server must not be nil") + } + + return s.server.URL +} + +func (s *integrationMockCoderServer) Close() { + if s == nil || s.server == nil { + return + } + + s.server.Close() +} + +func splitPath(path string) []string { + trimmed := strings.Trim(path, "/") + if trimmed == "" { + return nil + } + + return strings.Split(trimmed, "/") +} + +func hasSegments(segments []string, expected ...string) bool { + if len(segments) < len(expected) { + return false + } + + for i, segment := range expected { + if segments[i] != segment { + return false + } + } + + return true +} + +func writeJSON(w http.ResponseWriter, statusCode int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + _ = json.NewEncoder(w).Encode(payload) +} + +func writeCoderError(w http.ResponseWriter, statusCode int, message string) { + writeJSON(w, statusCode, codersdk.Response{Message: message}) +} diff --git a/main_test.go b/main_test.go index d54a1026..1dada6b8 100644 --- a/main_test.go +++ b/main_test.go @@ -5,10 +5,12 @@ import ( "errors" "strings" "testing" + "time" "k8s.io/apimachinery/pkg/runtime/schema" coderv1alpha1 "github.com/coder/coder-k8s/api/v1alpha1" + "github.com/coder/coder-k8s/internal/app/apiserverapp" "github.com/coder/coder-k8s/internal/app/controllerapp" "github.com/coder/coder-k8s/internal/controller" ) @@ -102,15 +104,33 @@ func TestRunDispatchesAggregatedAPIServerMode(t *testing.T) { expectedErr := errors.New("sentinel aggregated-apiserver error") called := false - runAggregatedAPIServerApp = func(ctx context.Context) error { + runAggregatedAPIServerApp = func(ctx context.Context, opts apiserverapp.Options) error { called = true if ctx == nil { t.Fatal("expected non-nil context passed to aggregated apiserver runner") } + if got, want := opts.CoderURL, "https://coder.example.com"; got != want { + t.Fatalf("expected coder URL %q, got %q", want, got) + } + if got, want := opts.CoderSessionToken, "test-token"; got != want { + t.Fatalf("expected coder session token %q, got %q", want, got) + } + if got, want := opts.CoderNamespace, "control-plane"; got != want { + t.Fatalf("expected coder namespace %q, got %q", want, got) + } + if got, want := opts.CoderRequestTimeout, 45*time.Second; got != want { + t.Fatalf("expected coder request timeout %v, got %v", want, got) + } return expectedErr } - err := run([]string{"--app=aggregated-apiserver"}) + err := run([]string{ + "--app=aggregated-apiserver", + "--coder-url=https://coder.example.com", + "--coder-session-token=test-token", + "--coder-namespace=control-plane", + "--coder-request-timeout=45s", + }) if !called { t.Fatal("expected aggregated apiserver runner to be called") } @@ -119,6 +139,66 @@ func TestRunDispatchesAggregatedAPIServerMode(t *testing.T) { } } +func TestRunRejectsAggregatedAPIServerModeWithCoderURLMissingScheme(t *testing.T) { + t.Helper() + installMockSignalHandler(t) + + previous := runAggregatedAPIServerApp + t.Cleanup(func() { + runAggregatedAPIServerApp = previous + }) + + called := false + runAggregatedAPIServerApp = func(context.Context, apiserverapp.Options) error { + called = true + return nil + } + + err := run([]string{ + "--app=aggregated-apiserver", + "--coder-url=coder.example.com", + }) + if err == nil { + t.Fatal("expected an error when --coder-url omits scheme") + } + if !strings.Contains(err.Error(), "must include scheme and host") { + t.Fatalf("expected missing scheme/host validation error, got %v", err) + } + if called { + t.Fatal("expected aggregated apiserver runner not to be called on invalid --coder-url") + } +} + +func TestRunRejectsAggregatedAPIServerModeWithUnsupportedCoderURLScheme(t *testing.T) { + t.Helper() + installMockSignalHandler(t) + + previous := runAggregatedAPIServerApp + t.Cleanup(func() { + runAggregatedAPIServerApp = previous + }) + + called := false + runAggregatedAPIServerApp = func(context.Context, apiserverapp.Options) error { + called = true + return nil + } + + err := run([]string{ + "--app=aggregated-apiserver", + "--coder-url=ftp://coder.example.com", + }) + if err == nil { + t.Fatal("expected an error when --coder-url has unsupported scheme") + } + if !strings.Contains(err.Error(), "scheme must be http or https") { + t.Fatalf("expected scheme validation error, got %v", err) + } + if called { + t.Fatal("expected aggregated apiserver runner not to be called on invalid --coder-url") + } +} + func TestRunDispatchesMCPHTTPMode(t *testing.T) { t.Helper() installMockSignalHandler(t) diff --git a/scripts/check_pr_reviews.sh b/scripts/check_pr_reviews.sh index 49df5fb9..7933a6c9 100755 --- a/scripts/check_pr_reviews.sh +++ b/scripts/check_pr_reviews.sh @@ -47,7 +47,9 @@ GRAPHQL_QUERY='query($owner: String!, $repo: String!, $pr: Int!, $cursor: String }' THREAD_CURSOR="" -ALL_THREADS='[]' +ALL_THREADS_FILE=$(mktemp) +trap 'rm -f "$ALL_THREADS_FILE"' EXIT +printf '[]\n' >"$ALL_THREADS_FILE" while true; do if [ -n "$THREAD_CURSOR" ]; then @@ -71,7 +73,10 @@ while true; do fi PAGE_THREADS=$(echo "$RESULT" | jq '.data.repository.pullRequest.reviewThreads.nodes') - ALL_THREADS=$(jq -cn --argjson all "$ALL_THREADS" --argjson page "$PAGE_THREADS" '$all + $page') + + MERGED_THREADS_FILE=$(mktemp) + jq -s '.[0] + .[1]' "$ALL_THREADS_FILE" <(printf '%s\n' "$PAGE_THREADS") >"$MERGED_THREADS_FILE" + mv "$MERGED_THREADS_FILE" "$ALL_THREADS_FILE" HAS_NEXT=$(echo "$RESULT" | jq -r '.data.repository.pullRequest.reviewThreads.pageInfo.hasNextPage') if [ "$HAS_NEXT" != "true" ]; then @@ -85,7 +90,7 @@ while true; do fi done -UNRESOLVED=$(echo "$ALL_THREADS" | jq -c '.[] | select(.isResolved == false) | {thread_id: .id, user: (.comments.nodes[0].author.login // "unknown"), body: (.comments.nodes[0].body // ""), diff_hunk: (.comments.nodes[0].diffHunk // ""), commit_id: (.comments.nodes[0].commit.oid // "")}') +UNRESOLVED=$(jq -c '.[] | select(.isResolved == false) | {thread_id: .id, user: (.comments.nodes[0].author.login // "unknown"), body: (.comments.nodes[0].body // ""), diff_hunk: (.comments.nodes[0].diffHunk // ""), commit_id: (.comments.nodes[0].commit.oid // "")}' "$ALL_THREADS_FILE") if [ -n "$UNRESOLVED" ]; then echo "❌ Unresolved review comments found:"