diff --git a/deploy/rbac.yaml b/deploy/rbac.yaml index 94453edd..ca207922 100644 --- a/deploy/rbac.yaml +++ b/deploy/rbac.yaml @@ -84,9 +84,12 @@ rules: verbs: ["get", "list", "watch"] - apiGroups: ["aggregation.coder.com"] resources: ["coderworkspaces", "codertemplates"] + verbs: ["get", "list", "watch", "update", "patch"] + - apiGroups: ["apps"] + resources: ["deployments"] verbs: ["get", "list", "watch"] - apiGroups: [""] - resources: ["pods", "pods/log", "events", "namespaces"] + resources: ["services", "pods", "pods/log", "events", "namespaces"] verbs: ["get", "list", "watch"] --- apiVersion: rbac.authorization.k8s.io/v1 diff --git a/docs/how-to/mcp-server.md b/docs/how-to/mcp-server.md index 5c23c183..2c3a9c64 100644 --- a/docs/how-to/mcp-server.md +++ b/docs/how-to/mcp-server.md @@ -6,11 +6,12 @@ The MCP server runs in HTTP mode (`--app=mcp-http`). ## 1. Overview -The MCP server provides tools for inspecting Kubernetes resources managed by `coder-k8s`, including: +The MCP server provides tools for inspecting and updating Kubernetes resources managed by `coder-k8s`, including: - `CoderControlPlane` resources -- `CoderWorkspace` resources -- `CoderTemplate` resources +- Control-plane Deployment, Service, and Pod status +- `CoderWorkspace` resources (including `spec.running` updates) +- `CoderTemplate` resources (including `spec.running` updates) - Namespace events - Pod logs @@ -38,12 +39,22 @@ http://127.0.0.1:8090/mcp ## 3. Available tools -The server exposes MCP tools for: - -- Reading `CoderControlPlane` resources and status -- Listing `CoderWorkspace` and `CoderTemplate` resources -- Listing namespace events for troubleshooting -- Reading pod logs for debugging +The server exposes the following MCP tools: + +- `list_control_planes` +- `get_control_plane_status` +- `list_control_plane_pods` +- `get_control_plane_deployment_status` +- `get_service_status` +- `list_workspaces` +- `get_workspace` +- `set_workspace_running` +- `list_templates` +- `get_template` +- `set_template_running` +- `get_events` +- `get_pod_logs` +- `check_health` ## 4. Health checks diff --git a/internal/aggregated/storage/helpers.go b/internal/aggregated/storage/helpers.go new file mode 100644 index 00000000..1f07e7c9 --- /dev/null +++ b/internal/aggregated/storage/helpers.go @@ -0,0 +1,43 @@ +package storage + +import ( + "context" + "fmt" + "strconv" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + genericapirequest "k8s.io/apiserver/pkg/endpoints/request" +) + +func resolveWriteNamespace(ctx context.Context, objectNamespace string) (string, error) { + requestNamespace := genericapirequest.NamespaceValue(ctx) + if requestNamespace == "" && objectNamespace == "" { + return "", apierrors.NewBadRequest("namespace is required") + } + if requestNamespace == "" { + return objectNamespace, nil + } + if objectNamespace == "" { + return requestNamespace, nil + } + if requestNamespace != objectNamespace { + return "", apierrors.NewBadRequest(fmt.Sprintf("request namespace %q does not match object namespace %q", requestNamespace, objectNamespace)) + } + return requestNamespace, nil +} + +func incrementResourceVersion(resourceVersion string) (string, error) { + if resourceVersion == "" { + return "1", nil + } + + version, err := strconv.ParseInt(resourceVersion, 10, 64) + if err != nil { + return "", fmt.Errorf("assertion failed: invalid resourceVersion %q: %w", resourceVersion, err) + } + if version < 0 { + return "", fmt.Errorf("assertion failed: resourceVersion must not be negative: %d", version) + } + + return strconv.FormatInt(version+1, 10), nil +} diff --git a/internal/aggregated/storage/template.go b/internal/aggregated/storage/template.go index 0bb93867..3630104b 100644 --- a/internal/aggregated/storage/template.go +++ b/internal/aggregated/storage/template.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sort" + "sync" "time" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -20,12 +21,15 @@ var ( _ rest.Storage = (*TemplateStorage)(nil) _ rest.Getter = (*TemplateStorage)(nil) _ rest.Lister = (*TemplateStorage)(nil) + _ rest.Updater = (*TemplateStorage)(nil) + _ rest.GracefulDeleter = (*TemplateStorage)(nil) _ rest.Scoper = (*TemplateStorage)(nil) _ rest.SingularNameProvider = (*TemplateStorage)(nil) ) // TemplateStorage provides hardcoded CoderTemplate objects. type TemplateStorage struct { + mu sync.RWMutex tableConvertor rest.TableConvertor templates map[string]*aggregationv1alpha1.CoderTemplate } @@ -45,8 +49,10 @@ func NewTemplateStorage() *TemplateStorage { APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, ObjectMeta: metav1.ObjectMeta{ - Name: "starter-template", - Namespace: "default", + Name: "starter-template", + Namespace: "default", + ResourceVersion: "1", + Generation: 1, }, Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, Status: aggregationv1alpha1.CoderTemplateStatus{ @@ -59,8 +65,10 @@ func NewTemplateStorage() *TemplateStorage { APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, ObjectMeta: metav1.ObjectMeta{ - Name: "platform-template", - Namespace: "default", + Name: "platform-template", + Namespace: "default", + ResourceVersion: "1", + Generation: 1, }, Spec: aggregationv1alpha1.CoderTemplateSpec{Running: false}, Status: aggregationv1alpha1.CoderTemplateStatus{ @@ -73,8 +81,10 @@ func NewTemplateStorage() *TemplateStorage { APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, ObjectMeta: metav1.ObjectMeta{ - Name: "docs-template", - Namespace: "sandbox", + Name: "docs-template", + Namespace: "sandbox", + ResourceVersion: "1", + Generation: 1, }, Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, Status: aggregationv1alpha1.CoderTemplateStatus{ @@ -125,20 +135,26 @@ func (s *TemplateStorage) Get(ctx context.Context, name string, _ *metav1.GetOpt } namespace := genericapirequest.NamespaceValue(ctx) + + s.mu.RLock() + defer s.mu.RUnlock() + if namespace != "" { - if template, ok := s.templates[templateKey(namespace, name)]; ok { - return template.DeepCopy(), nil + template, ok := s.templates[templateKey(namespace, name)] + if !ok { + return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) } - return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) + return template.DeepCopy(), nil } - for _, template := range s.templates { - if template.Name == name { - return template.DeepCopy(), nil - } + template, found, ambiguous := s.findTemplateByNameLocked(name) + if ambiguous { + return nil, apierrors.NewBadRequest(fmt.Sprintf("template name %q is ambiguous across namespaces; specify namespace", name)) } - - return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) + if !found { + return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) + } + return template.DeepCopy(), nil } // List returns hardcoded CoderTemplate objects. @@ -156,9 +172,12 @@ func (s *TemplateStorage) List(ctx context.Context, _ *metainternalversion.ListO Kind: "CoderTemplateList", APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, - Items: make([]aggregationv1alpha1.CoderTemplate, 0, len(s.templates)), + Items: make([]aggregationv1alpha1.CoderTemplate, 0), } + s.mu.RLock() + defer s.mu.RUnlock() + keys := make([]string, 0, len(s.templates)) for key := range s.templates { keys = append(keys, key) @@ -176,6 +195,268 @@ func (s *TemplateStorage) List(ctx context.Context, _ *metainternalversion.ListO return list, nil } +// Create inserts a CoderTemplate into the in-memory store. +func (s *TemplateStorage) Create( + ctx context.Context, + obj runtime.Object, + createValidation rest.ValidateObjectFunc, + _ *metav1.CreateOptions, +) (runtime.Object, error) { + if s == nil { + return nil, fmt.Errorf("assertion failed: template storage must not be nil") + } + if ctx == nil { + return nil, fmt.Errorf("assertion failed: context must not be nil") + } + if obj == nil { + return nil, fmt.Errorf("assertion failed: object must not be nil") + } + + template, ok := obj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + return nil, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderTemplate, got %T", obj)) + } + + candidate := template.DeepCopy() + if candidate.Name == "" { + return nil, apierrors.NewBadRequest("metadata.name is required") + } + + namespace, err := resolveWriteNamespace(ctx, candidate.Namespace) + if err != nil { + return nil, err + } + candidate.Namespace = namespace + + ensureTemplateTypeMeta(candidate) + if candidate.Generation == 0 { + candidate.Generation = 1 + } + if candidate.CreationTimestamp.IsZero() { + candidate.CreationTimestamp = metav1.Now() + } + candidate.ResourceVersion = "1" + + if createValidation != nil { + if err := createValidation(ctx, candidate); err != nil { + return nil, err + } + } + + key := templateKey(candidate.Namespace, candidate.Name) + + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.templates[key]; exists { + return nil, apierrors.NewAlreadyExists(aggregationv1alpha1.Resource("codertemplates"), candidate.Name) + } + + s.templates[key] = candidate.DeepCopy() + return candidate.DeepCopy(), nil +} + +// Update modifies an existing CoderTemplate in the in-memory store. +func (s *TemplateStorage) Update( + ctx context.Context, + name string, + objInfo rest.UpdatedObjectInfo, + createValidation rest.ValidateObjectFunc, + updateValidation rest.ValidateObjectUpdateFunc, + forceAllowCreate bool, + _ *metav1.UpdateOptions, +) (runtime.Object, bool, error) { + if s == nil { + return nil, false, fmt.Errorf("assertion failed: template storage must not be nil") + } + if ctx == nil { + return nil, false, fmt.Errorf("assertion failed: context must not be nil") + } + if name == "" { + return nil, false, fmt.Errorf("assertion failed: template name must not be empty") + } + if objInfo == nil { + return nil, false, fmt.Errorf("assertion failed: updated object info must not be nil") + } + + namespace := genericapirequest.NamespaceValue(ctx) + if namespace == "" { + return nil, false, apierrors.NewBadRequest("namespace is required") + } + + key := templateKey(namespace, name) + + s.mu.Lock() + defer s.mu.Unlock() + + existing, exists := s.templates[key] + if !exists { + if !forceAllowCreate { + return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) + } + + createdObj, err := objInfo.UpdatedObject(ctx, &aggregationv1alpha1.CoderTemplate{}) + if err != nil { + return nil, false, err + } + createdTemplate, ok := createdObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderTemplate, got %T", createdObj)) + } + + candidate := createdTemplate.DeepCopy() + if candidate.Name == "" { + candidate.Name = name + } + if candidate.Name != name { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) + } + if candidate.Namespace == "" { + candidate.Namespace = namespace + } + if candidate.Namespace != namespace { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), + ) + } + + ensureTemplateTypeMeta(candidate) + if candidate.Generation == 0 { + candidate.Generation = 1 + } + if candidate.CreationTimestamp.IsZero() { + candidate.CreationTimestamp = metav1.Now() + } + candidate.ResourceVersion = "1" + + if createValidation != nil { + if err := createValidation(ctx, candidate); err != nil { + return nil, false, err + } + } + + s.templates[key] = candidate.DeepCopy() + return candidate.DeepCopy(), true, nil + } + + updatedObj, err := objInfo.UpdatedObject(ctx, existing.DeepCopy()) + if err != nil { + return nil, false, err + } + updatedTemplate, ok := updatedObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderTemplate, got %T", updatedObj)) + } + + candidate := updatedTemplate.DeepCopy() + if candidate.Name == "" { + candidate.Name = name + } + if candidate.Name != name { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) + } + if candidate.Namespace == "" { + candidate.Namespace = namespace + } + if candidate.Namespace != namespace { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), + ) + } + + if candidate.ResourceVersion == "" { + return nil, false, apierrors.NewBadRequest("metadata.resourceVersion is required for update") + } + if candidate.ResourceVersion != existing.ResourceVersion { + return nil, false, apierrors.NewConflict( + aggregationv1alpha1.Resource("codertemplates"), + name, + fmt.Errorf("resourceVersion %q does not match current value %q", candidate.ResourceVersion, existing.ResourceVersion), + ) + } + + candidate.Status = existing.Status + candidate.CreationTimestamp = existing.CreationTimestamp + candidate.Generation = existing.Generation + 1 + candidateFinalResourceVersion, err := incrementResourceVersion(existing.ResourceVersion) + if err != nil { + return nil, false, err + } + candidate.ResourceVersion = candidateFinalResourceVersion + ensureTemplateTypeMeta(candidate) + + if updateValidation != nil { + if err := updateValidation(ctx, candidate, existing); err != nil { + return nil, false, err + } + } + + s.templates[key] = candidate.DeepCopy() + return candidate.DeepCopy(), false, nil +} + +// Delete removes a CoderTemplate from the in-memory store. +func (s *TemplateStorage) Delete( + ctx context.Context, + name string, + deleteValidation rest.ValidateObjectFunc, + _ *metav1.DeleteOptions, +) (runtime.Object, bool, error) { + if s == nil { + return nil, false, fmt.Errorf("assertion failed: template storage must not be nil") + } + if ctx == nil { + return nil, false, fmt.Errorf("assertion failed: context must not be nil") + } + if name == "" { + return nil, false, fmt.Errorf("assertion failed: template name must not be empty") + } + + namespace := genericapirequest.NamespaceValue(ctx) + + s.mu.Lock() + defer s.mu.Unlock() + + var ( + key string + template *aggregationv1alpha1.CoderTemplate + ) + if namespace != "" { + key = templateKey(namespace, name) + template = s.templates[key] + } else { + matchedKeys := make([]string, 0) + for candidateKey, candidateTemplate := range s.templates { + if candidateTemplate.Name == name { + matchedKeys = append(matchedKeys, candidateKey) + } + } + if len(matchedKeys) > 1 { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("template name %q is ambiguous across namespaces; specify namespace", name), + ) + } + if len(matchedKeys) == 1 { + key = matchedKeys[0] + template = s.templates[key] + } + } + + if template == nil { + return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), name) + } + + if deleteValidation != nil { + if err := deleteValidation(ctx, template.DeepCopy()); err != nil { + return nil, false, err + } + } + + deleted := template.DeepCopy() + delete(s.templates, key) + return deleted, true, nil +} + // ConvertToTable converts a template object or list into kubectl table output. func (s *TemplateStorage) ConvertToTable(ctx context.Context, object, tableOptions runtime.Object) (*metav1.Table, error) { if s == nil { @@ -187,3 +468,33 @@ func (s *TemplateStorage) ConvertToTable(ctx context.Context, object, tableOptio return s.tableConvertor.ConvertToTable(ctx, object, tableOptions) } + +func ensureTemplateTypeMeta(template *aggregationv1alpha1.CoderTemplate) { + if template == nil { + panic("assertion failed: template must not be nil") + } + template.TypeMeta = metav1.TypeMeta{ + Kind: "CoderTemplate", + APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), + } +} + +func (s *TemplateStorage) findTemplateByNameLocked(name string) (*aggregationv1alpha1.CoderTemplate, bool, bool) { + matchedKeys := make([]string, 0) + for key, template := range s.templates { + if template.Name == name { + matchedKeys = append(matchedKeys, key) + } + } + if len(matchedKeys) == 0 { + return nil, false, false + } + if len(matchedKeys) > 1 { + return nil, false, true + } + template := s.templates[matchedKeys[0]] + if template == nil { + return nil, false, false + } + return template, true, false +} diff --git a/internal/aggregated/storage/template_test.go b/internal/aggregated/storage/template_test.go new file mode 100644 index 00000000..73d2bec7 --- /dev/null +++ b/internal/aggregated/storage/template_test.go @@ -0,0 +1,231 @@ +package storage + +import ( + "context" + "testing" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + genericapirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/apiserver/pkg/registry/rest" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" +) + +func TestTemplateStorageCRUDLifecycle(t *testing.T) { + t.Helper() + + templateStorage := NewTemplateStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + createdObj, err := templateStorage.Create(ctx, &aggregationv1alpha1.CoderTemplate{ + ObjectMeta: metav1.ObjectMeta{Name: "unit-template"}, + Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, + }, rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("create template: %v", err) + } + + created, ok := createdObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from create, got %T", createdObj) + } + if created.Namespace != "default" { + t.Fatalf("expected namespace default, got %q", created.Namespace) + } + if created.ResourceVersion != "1" { + t.Fatalf("expected resourceVersion 1, got %q", created.ResourceVersion) + } + if created.Generation != 1 { + t.Fatalf("expected generation 1, got %d", created.Generation) + } + + toUpdate := created.DeepCopy() + toUpdate.Spec.Running = false + toUpdate.ResourceVersion = created.ResourceVersion + updatedObj, createdOnUpdate, err := templateStorage.Update( + ctx, + toUpdate.Name, + rest.DefaultUpdatedObjectInfo(toUpdate), + nil, + nil, + false, + nil, + ) + if err != nil { + t.Fatalf("update template: %v", err) + } + if createdOnUpdate { + t.Fatal("expected update of existing template, got createdOnUpdate=true") + } + + updated, ok := updatedObj.(*aggregationv1alpha1.CoderTemplate) + if !ok { + t.Fatalf("expected *CoderTemplate from update, got %T", updatedObj) + } + if updated.Spec.Running { + t.Fatalf("expected running=false after update, got %+v", updated.Spec) + } + if updated.ResourceVersion == created.ResourceVersion { + t.Fatalf("expected resourceVersion to change, got %q", updated.ResourceVersion) + } + if updated.Generation != created.Generation+1 { + t.Fatalf("expected generation increment to %d, got %d", created.Generation+1, updated.Generation) + } + + deletedObj, deletedNow, err := templateStorage.Delete(ctx, created.Name, nil, nil) + if err != nil { + t.Fatalf("delete template: %v", err) + } + if !deletedNow { + t.Fatal("expected immediate delete") + } + if _, ok := deletedObj.(*aggregationv1alpha1.CoderTemplate); !ok { + t.Fatalf("expected *CoderTemplate from delete, got %T", deletedObj) + } + + _, err = templateStorage.Get(ctx, created.Name, nil) + if !apierrors.IsNotFound(err) { + t.Fatalf("expected NotFound after delete, got %v", err) + } +} + +func TestTemplateStorageCreateAlreadyExists(t *testing.T) { + t.Helper() + + templateStorage := NewTemplateStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + _, err := templateStorage.Create(ctx, &aggregationv1alpha1.CoderTemplate{ + ObjectMeta: metav1.ObjectMeta{Name: "starter-template"}, + Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, + }, nil, nil) + if !apierrors.IsAlreadyExists(err) { + t.Fatalf("expected AlreadyExists error, got %v", err) + } +} + +func TestTemplateStorageUpdateRejectsNamespaceChange(t *testing.T) { + t.Helper() + + templateStorage := NewTemplateStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + currentObj, err := templateStorage.Get(ctx, "starter-template", nil) + if err != nil { + t.Fatalf("get template: %v", err) + } + current := currentObj.(*aggregationv1alpha1.CoderTemplate) + + modified := current.DeepCopy() + modified.Namespace = "sandbox" + modified.ResourceVersion = current.ResourceVersion + + _, _, err = templateStorage.Update( + ctx, + modified.Name, + rest.DefaultUpdatedObjectInfo(modified), + nil, + nil, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for namespace mismatch, got %v", err) + } +} + +func TestTemplateStorageUpdateIgnoresStatusWrites(t *testing.T) { + t.Helper() + + templateStorage := NewTemplateStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + currentObj, err := templateStorage.Get(ctx, "starter-template", nil) + if err != nil { + t.Fatalf("get template: %v", err) + } + current := currentObj.(*aggregationv1alpha1.CoderTemplate) + if current.Status.AutoShutdown == nil { + t.Fatal("expected seeded template status autoShutdown") + } + + modified := current.DeepCopy() + modified.Spec.Running = !current.Spec.Running + modified.ResourceVersion = current.ResourceVersion + overrideDeadline := metav1.NewTime(time.Date(2040, time.January, 1, 0, 0, 0, 0, time.UTC)) + modified.Status.AutoShutdown = &overrideDeadline + + updatedObj, _, err := templateStorage.Update( + ctx, + modified.Name, + rest.DefaultUpdatedObjectInfo(modified), + nil, + nil, + false, + nil, + ) + if err != nil { + t.Fatalf("update template: %v", err) + } + + updated := updatedObj.(*aggregationv1alpha1.CoderTemplate) + if updated.Status.AutoShutdown == nil { + t.Fatal("expected status autoShutdown to remain present") + } + if !updated.Status.AutoShutdown.Equal(current.Status.AutoShutdown) { + t.Fatalf("expected status to remain unchanged, got %s want %s", updated.Status.AutoShutdown, current.Status.AutoShutdown) + } +} + +func TestTemplateStorageDeleteAmbiguousWithoutNamespace(t *testing.T) { + t.Helper() + + templateStorage := NewTemplateStorage() + + _, err := templateStorage.Create( + genericapirequest.WithNamespace(context.Background(), "sandbox"), + &aggregationv1alpha1.CoderTemplate{ObjectMeta: metav1.ObjectMeta{Name: "starter-template"}}, + nil, + nil, + ) + if err != nil { + t.Fatalf("seed same-name template in sandbox namespace: %v", err) + } + + _, _, err = templateStorage.Delete(context.Background(), "starter-template", nil, nil) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for ambiguous delete, got %v", err) + } +} + +func TestTemplateStorageUpdateRequiresResourceVersion(t *testing.T) { + t.Helper() + + templateStorage := NewTemplateStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + currentObj, err := templateStorage.Get(ctx, "starter-template", nil) + if err != nil { + t.Fatalf("get template: %v", err) + } + current := currentObj.(*aggregationv1alpha1.CoderTemplate) + + modified := current.DeepCopy() + modified.Spec.Running = !current.Spec.Running + modified.ResourceVersion = "" + + _, _, err = templateStorage.Update( + ctx, + modified.Name, + rest.DefaultUpdatedObjectInfo(modified), + nil, + nil, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when resourceVersion is missing, got %v", err) + } +} diff --git a/internal/aggregated/storage/workspace.go b/internal/aggregated/storage/workspace.go index d92a185e..522d1d1e 100644 --- a/internal/aggregated/storage/workspace.go +++ b/internal/aggregated/storage/workspace.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "sort" + "sync" "time" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -21,12 +22,15 @@ var ( _ rest.Storage = (*WorkspaceStorage)(nil) _ rest.Getter = (*WorkspaceStorage)(nil) _ rest.Lister = (*WorkspaceStorage)(nil) + _ rest.Updater = (*WorkspaceStorage)(nil) + _ rest.GracefulDeleter = (*WorkspaceStorage)(nil) _ rest.Scoper = (*WorkspaceStorage)(nil) _ rest.SingularNameProvider = (*WorkspaceStorage)(nil) ) // WorkspaceStorage provides hardcoded CoderWorkspace objects. type WorkspaceStorage struct { + mu sync.RWMutex tableConvertor rest.TableConvertor workspaces map[string]*aggregationv1alpha1.CoderWorkspace } @@ -46,8 +50,10 @@ func NewWorkspaceStorage() *WorkspaceStorage { APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, ObjectMeta: metav1.ObjectMeta{ - Name: "dev-workspace", - Namespace: "default", + Name: "dev-workspace", + Namespace: "default", + ResourceVersion: "1", + Generation: 1, }, Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, Status: aggregationv1alpha1.CoderWorkspaceStatus{ @@ -60,8 +66,10 @@ func NewWorkspaceStorage() *WorkspaceStorage { APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, ObjectMeta: metav1.ObjectMeta{ - Name: "staging-workspace", - Namespace: "default", + Name: "staging-workspace", + Namespace: "default", + ResourceVersion: "1", + Generation: 1, }, Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: false}, Status: aggregationv1alpha1.CoderWorkspaceStatus{ @@ -74,8 +82,10 @@ func NewWorkspaceStorage() *WorkspaceStorage { APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, ObjectMeta: metav1.ObjectMeta{ - Name: "sandbox-workspace", - Namespace: "sandbox", + Name: "sandbox-workspace", + Namespace: "sandbox", + ResourceVersion: "1", + Generation: 1, }, Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, Status: aggregationv1alpha1.CoderWorkspaceStatus{ @@ -126,20 +136,26 @@ func (s *WorkspaceStorage) Get(ctx context.Context, name string, _ *metav1.GetOp } namespace := genericapirequest.NamespaceValue(ctx) + + s.mu.RLock() + defer s.mu.RUnlock() + if namespace != "" { - if workspace, ok := s.workspaces[workspaceKey(namespace, name)]; ok { - return workspace.DeepCopy(), nil + workspace, ok := s.workspaces[workspaceKey(namespace, name)] + if !ok { + return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) } - return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) + return workspace.DeepCopy(), nil } - for _, workspace := range s.workspaces { - if workspace.Name == name { - return workspace.DeepCopy(), nil - } + workspace, found, ambiguous := s.findWorkspaceByNameLocked(name) + if ambiguous { + return nil, apierrors.NewBadRequest(fmt.Sprintf("workspace name %q is ambiguous across namespaces; specify namespace", name)) } - - return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) + if !found { + return nil, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) + } + return workspace.DeepCopy(), nil } // List returns hardcoded CoderWorkspace objects. @@ -157,9 +173,12 @@ func (s *WorkspaceStorage) List(ctx context.Context, _ *metainternalversion.List Kind: "CoderWorkspaceList", APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), }, - Items: make([]aggregationv1alpha1.CoderWorkspace, 0, len(s.workspaces)), + Items: make([]aggregationv1alpha1.CoderWorkspace, 0), } + s.mu.RLock() + defer s.mu.RUnlock() + keys := make([]string, 0, len(s.workspaces)) for key := range s.workspaces { keys = append(keys, key) @@ -177,6 +196,268 @@ func (s *WorkspaceStorage) List(ctx context.Context, _ *metainternalversion.List return list, nil } +// Create inserts a CoderWorkspace into the in-memory store. +func (s *WorkspaceStorage) Create( + ctx context.Context, + obj runtime.Object, + createValidation rest.ValidateObjectFunc, + _ *metav1.CreateOptions, +) (runtime.Object, error) { + if s == nil { + return nil, fmt.Errorf("assertion failed: workspace storage must not be nil") + } + if ctx == nil { + return nil, fmt.Errorf("assertion failed: context must not be nil") + } + if obj == nil { + return nil, fmt.Errorf("assertion failed: object must not be nil") + } + + workspace, ok := obj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + return nil, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderWorkspace, got %T", obj)) + } + + candidate := workspace.DeepCopy() + if candidate.Name == "" { + return nil, apierrors.NewBadRequest("metadata.name is required") + } + + namespace, err := resolveWriteNamespace(ctx, candidate.Namespace) + if err != nil { + return nil, err + } + candidate.Namespace = namespace + + ensureWorkspaceTypeMeta(candidate) + if candidate.Generation == 0 { + candidate.Generation = 1 + } + if candidate.CreationTimestamp.IsZero() { + candidate.CreationTimestamp = metav1.Now() + } + candidate.ResourceVersion = "1" + + if createValidation != nil { + if err := createValidation(ctx, candidate); err != nil { + return nil, err + } + } + + key := workspaceKey(candidate.Namespace, candidate.Name) + + s.mu.Lock() + defer s.mu.Unlock() + + if _, exists := s.workspaces[key]; exists { + return nil, apierrors.NewAlreadyExists(aggregationv1alpha1.Resource("coderworkspaces"), candidate.Name) + } + + s.workspaces[key] = candidate.DeepCopy() + return candidate.DeepCopy(), nil +} + +// Update modifies an existing CoderWorkspace in the in-memory store. +func (s *WorkspaceStorage) Update( + ctx context.Context, + name string, + objInfo rest.UpdatedObjectInfo, + createValidation rest.ValidateObjectFunc, + updateValidation rest.ValidateObjectUpdateFunc, + forceAllowCreate bool, + _ *metav1.UpdateOptions, +) (runtime.Object, bool, error) { + if s == nil { + return nil, false, fmt.Errorf("assertion failed: workspace storage must not be nil") + } + if ctx == nil { + return nil, false, fmt.Errorf("assertion failed: context must not be nil") + } + if name == "" { + return nil, false, fmt.Errorf("assertion failed: workspace name must not be empty") + } + if objInfo == nil { + return nil, false, fmt.Errorf("assertion failed: updated object info must not be nil") + } + + namespace := genericapirequest.NamespaceValue(ctx) + if namespace == "" { + return nil, false, apierrors.NewBadRequest("namespace is required") + } + + key := workspaceKey(namespace, name) + + s.mu.Lock() + defer s.mu.Unlock() + + existing, exists := s.workspaces[key] + if !exists { + if !forceAllowCreate { + return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) + } + + createdObj, err := objInfo.UpdatedObject(ctx, &aggregationv1alpha1.CoderWorkspace{}) + if err != nil { + return nil, false, err + } + createdWorkspace, ok := createdObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderWorkspace, got %T", createdObj)) + } + + candidate := createdWorkspace.DeepCopy() + if candidate.Name == "" { + candidate.Name = name + } + if candidate.Name != name { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) + } + if candidate.Namespace == "" { + candidate.Namespace = namespace + } + if candidate.Namespace != namespace { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), + ) + } + + ensureWorkspaceTypeMeta(candidate) + if candidate.Generation == 0 { + candidate.Generation = 1 + } + if candidate.CreationTimestamp.IsZero() { + candidate.CreationTimestamp = metav1.Now() + } + candidate.ResourceVersion = "1" + + if createValidation != nil { + if err := createValidation(ctx, candidate); err != nil { + return nil, false, err + } + } + + s.workspaces[key] = candidate.DeepCopy() + return candidate.DeepCopy(), true, nil + } + + updatedObj, err := objInfo.UpdatedObject(ctx, existing.DeepCopy()) + if err != nil { + return nil, false, err + } + updatedWorkspace, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("expected *CoderWorkspace, got %T", updatedObj)) + } + + candidate := updatedWorkspace.DeepCopy() + if candidate.Name == "" { + candidate.Name = name + } + if candidate.Name != name { + return nil, false, apierrors.NewBadRequest(fmt.Sprintf("metadata.name %q must match request name %q", candidate.Name, name)) + } + if candidate.Namespace == "" { + candidate.Namespace = namespace + } + if candidate.Namespace != namespace { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("metadata.namespace %q must match request namespace %q", candidate.Namespace, namespace), + ) + } + + if candidate.ResourceVersion == "" { + return nil, false, apierrors.NewBadRequest("metadata.resourceVersion is required for update") + } + if candidate.ResourceVersion != existing.ResourceVersion { + return nil, false, apierrors.NewConflict( + aggregationv1alpha1.Resource("coderworkspaces"), + name, + fmt.Errorf("resourceVersion %q does not match current value %q", candidate.ResourceVersion, existing.ResourceVersion), + ) + } + + candidate.Status = existing.Status + candidate.CreationTimestamp = existing.CreationTimestamp + candidate.Generation = existing.Generation + 1 + candidateFinalResourceVersion, err := incrementResourceVersion(existing.ResourceVersion) + if err != nil { + return nil, false, err + } + candidate.ResourceVersion = candidateFinalResourceVersion + ensureWorkspaceTypeMeta(candidate) + + if updateValidation != nil { + if err := updateValidation(ctx, candidate, existing); err != nil { + return nil, false, err + } + } + + s.workspaces[key] = candidate.DeepCopy() + return candidate.DeepCopy(), false, nil +} + +// Delete removes a CoderWorkspace from the in-memory store. +func (s *WorkspaceStorage) Delete( + ctx context.Context, + name string, + deleteValidation rest.ValidateObjectFunc, + _ *metav1.DeleteOptions, +) (runtime.Object, bool, error) { + if s == nil { + return nil, false, fmt.Errorf("assertion failed: workspace storage must not be nil") + } + if ctx == nil { + return nil, false, fmt.Errorf("assertion failed: context must not be nil") + } + if name == "" { + return nil, false, fmt.Errorf("assertion failed: workspace name must not be empty") + } + + namespace := genericapirequest.NamespaceValue(ctx) + + s.mu.Lock() + defer s.mu.Unlock() + + var ( + key string + workspace *aggregationv1alpha1.CoderWorkspace + ) + if namespace != "" { + key = workspaceKey(namespace, name) + workspace = s.workspaces[key] + } else { + matchedKeys := make([]string, 0) + for candidateKey, candidateWorkspace := range s.workspaces { + if candidateWorkspace.Name == name { + matchedKeys = append(matchedKeys, candidateKey) + } + } + if len(matchedKeys) > 1 { + return nil, false, apierrors.NewBadRequest( + fmt.Sprintf("workspace name %q is ambiguous across namespaces; specify namespace", name), + ) + } + if len(matchedKeys) == 1 { + key = matchedKeys[0] + workspace = s.workspaces[key] + } + } + + if workspace == nil { + return nil, false, apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), name) + } + + if deleteValidation != nil { + if err := deleteValidation(ctx, workspace.DeepCopy()); err != nil { + return nil, false, err + } + } + + deleted := workspace.DeepCopy() + delete(s.workspaces, key) + return deleted, true, nil +} + // ConvertToTable converts a workspace object or list into kubectl table output. func (s *WorkspaceStorage) ConvertToTable(ctx context.Context, object, tableOptions runtime.Object) (*metav1.Table, error) { if s == nil { @@ -188,3 +469,33 @@ func (s *WorkspaceStorage) ConvertToTable(ctx context.Context, object, tableOpti return s.tableConvertor.ConvertToTable(ctx, object, tableOptions) } + +func ensureWorkspaceTypeMeta(workspace *aggregationv1alpha1.CoderWorkspace) { + if workspace == nil { + panic("assertion failed: workspace must not be nil") + } + workspace.TypeMeta = metav1.TypeMeta{ + Kind: "CoderWorkspace", + APIVersion: aggregationv1alpha1.SchemeGroupVersion.String(), + } +} + +func (s *WorkspaceStorage) findWorkspaceByNameLocked(name string) (*aggregationv1alpha1.CoderWorkspace, bool, bool) { + matchedKeys := make([]string, 0) + for key, workspace := range s.workspaces { + if workspace.Name == name { + matchedKeys = append(matchedKeys, key) + } + } + if len(matchedKeys) == 0 { + return nil, false, false + } + if len(matchedKeys) > 1 { + return nil, false, true + } + workspace := s.workspaces[matchedKeys[0]] + if workspace == nil { + return nil, false, false + } + return workspace, true, false +} diff --git a/internal/aggregated/storage/workspace_test.go b/internal/aggregated/storage/workspace_test.go new file mode 100644 index 00000000..b77cc98f --- /dev/null +++ b/internal/aggregated/storage/workspace_test.go @@ -0,0 +1,231 @@ +package storage + +import ( + "context" + "testing" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + genericapirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/apiserver/pkg/registry/rest" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" +) + +func TestWorkspaceStorageCRUDLifecycle(t *testing.T) { + t.Helper() + + workspaceStorage := NewWorkspaceStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + createdObj, err := workspaceStorage.Create(ctx, &aggregationv1alpha1.CoderWorkspace{ + ObjectMeta: metav1.ObjectMeta{Name: "unit-workspace"}, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, + }, rest.ValidateAllObjectFunc, nil) + if err != nil { + t.Fatalf("create workspace: %v", err) + } + + created, ok := createdObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from create, got %T", createdObj) + } + if created.Namespace != "default" { + t.Fatalf("expected namespace default, got %q", created.Namespace) + } + if created.ResourceVersion != "1" { + t.Fatalf("expected resourceVersion 1, got %q", created.ResourceVersion) + } + if created.Generation != 1 { + t.Fatalf("expected generation 1, got %d", created.Generation) + } + + toUpdate := created.DeepCopy() + toUpdate.Spec.Running = false + toUpdate.ResourceVersion = created.ResourceVersion + updatedObj, createdOnUpdate, err := workspaceStorage.Update( + ctx, + toUpdate.Name, + rest.DefaultUpdatedObjectInfo(toUpdate), + nil, + nil, + false, + nil, + ) + if err != nil { + t.Fatalf("update workspace: %v", err) + } + if createdOnUpdate { + t.Fatal("expected update of existing workspace, got createdOnUpdate=true") + } + + updated, ok := updatedObj.(*aggregationv1alpha1.CoderWorkspace) + if !ok { + t.Fatalf("expected *CoderWorkspace from update, got %T", updatedObj) + } + if updated.Spec.Running { + t.Fatalf("expected running=false after update, got %+v", updated.Spec) + } + if updated.ResourceVersion == created.ResourceVersion { + t.Fatalf("expected resourceVersion to change, got %q", updated.ResourceVersion) + } + if updated.Generation != created.Generation+1 { + t.Fatalf("expected generation increment to %d, got %d", created.Generation+1, updated.Generation) + } + + deletedObj, deletedNow, err := workspaceStorage.Delete(ctx, created.Name, nil, nil) + if err != nil { + t.Fatalf("delete workspace: %v", err) + } + if !deletedNow { + t.Fatal("expected immediate delete") + } + if _, ok := deletedObj.(*aggregationv1alpha1.CoderWorkspace); !ok { + t.Fatalf("expected *CoderWorkspace from delete, got %T", deletedObj) + } + + _, err = workspaceStorage.Get(ctx, created.Name, nil) + if !apierrors.IsNotFound(err) { + t.Fatalf("expected NotFound after delete, got %v", err) + } +} + +func TestWorkspaceStorageCreateAlreadyExists(t *testing.T) { + t.Helper() + + workspaceStorage := NewWorkspaceStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + _, err := workspaceStorage.Create(ctx, &aggregationv1alpha1.CoderWorkspace{ + ObjectMeta: metav1.ObjectMeta{Name: "dev-workspace"}, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: true}, + }, nil, nil) + if !apierrors.IsAlreadyExists(err) { + t.Fatalf("expected AlreadyExists error, got %v", err) + } +} + +func TestWorkspaceStorageUpdateRejectsNamespaceChange(t *testing.T) { + t.Helper() + + workspaceStorage := NewWorkspaceStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + currentObj, err := workspaceStorage.Get(ctx, "dev-workspace", nil) + if err != nil { + t.Fatalf("get workspace: %v", err) + } + current := currentObj.(*aggregationv1alpha1.CoderWorkspace) + + modified := current.DeepCopy() + modified.Namespace = "sandbox" + modified.ResourceVersion = current.ResourceVersion + + _, _, err = workspaceStorage.Update( + ctx, + modified.Name, + rest.DefaultUpdatedObjectInfo(modified), + nil, + nil, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for namespace mismatch, got %v", err) + } +} + +func TestWorkspaceStorageUpdateIgnoresStatusWrites(t *testing.T) { + t.Helper() + + workspaceStorage := NewWorkspaceStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + currentObj, err := workspaceStorage.Get(ctx, "dev-workspace", nil) + if err != nil { + t.Fatalf("get workspace: %v", err) + } + current := currentObj.(*aggregationv1alpha1.CoderWorkspace) + if current.Status.AutoShutdown == nil { + t.Fatal("expected seeded workspace status autoShutdown") + } + + modified := current.DeepCopy() + modified.Spec.Running = !current.Spec.Running + modified.ResourceVersion = current.ResourceVersion + overrideDeadline := metav1.NewTime(time.Date(2040, time.January, 1, 0, 0, 0, 0, time.UTC)) + modified.Status.AutoShutdown = &overrideDeadline + + updatedObj, _, err := workspaceStorage.Update( + ctx, + modified.Name, + rest.DefaultUpdatedObjectInfo(modified), + nil, + nil, + false, + nil, + ) + if err != nil { + t.Fatalf("update workspace: %v", err) + } + + updated := updatedObj.(*aggregationv1alpha1.CoderWorkspace) + if updated.Status.AutoShutdown == nil { + t.Fatal("expected status autoShutdown to remain present") + } + if !updated.Status.AutoShutdown.Equal(current.Status.AutoShutdown) { + t.Fatalf("expected status to remain unchanged, got %s want %s", updated.Status.AutoShutdown, current.Status.AutoShutdown) + } +} + +func TestWorkspaceStorageDeleteAmbiguousWithoutNamespace(t *testing.T) { + t.Helper() + + workspaceStorage := NewWorkspaceStorage() + + _, err := workspaceStorage.Create( + genericapirequest.WithNamespace(context.Background(), "sandbox"), + &aggregationv1alpha1.CoderWorkspace{ObjectMeta: metav1.ObjectMeta{Name: "dev-workspace"}}, + nil, + nil, + ) + if err != nil { + t.Fatalf("seed same-name workspace in sandbox namespace: %v", err) + } + + _, _, err = workspaceStorage.Delete(context.Background(), "dev-workspace", nil, nil) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest for ambiguous delete, got %v", err) + } +} + +func TestWorkspaceStorageUpdateRequiresResourceVersion(t *testing.T) { + t.Helper() + + workspaceStorage := NewWorkspaceStorage() + ctx := genericapirequest.WithNamespace(context.Background(), "default") + + currentObj, err := workspaceStorage.Get(ctx, "dev-workspace", nil) + if err != nil { + t.Fatalf("get workspace: %v", err) + } + current := currentObj.(*aggregationv1alpha1.CoderWorkspace) + + modified := current.DeepCopy() + modified.Spec.Running = !current.Spec.Running + modified.ResourceVersion = "" + + _, _, err = workspaceStorage.Update( + ctx, + modified.Name, + rest.DefaultUpdatedObjectInfo(modified), + nil, + nil, + false, + nil, + ) + if !apierrors.IsBadRequest(err) { + t.Fatalf("expected BadRequest when resourceVersion is missing, got %v", err) + } +} diff --git a/internal/app/mcpapp/tools.go b/internal/app/mcpapp/tools.go index 72ac7d1a..e6c853ae 100644 --- a/internal/app/mcpapp/tools.go +++ b/internal/app/mcpapp/tools.go @@ -4,11 +4,14 @@ import ( "context" "fmt" "io" + "math" + "sort" "time" aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" coderv1alpha1 "github.com/coder/coder-k8s/api/v1alpha1" "github.com/modelcontextprotocol/go-sdk/mcp" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" @@ -17,11 +20,13 @@ import ( ) const ( - defaultPodLogTailLines int64 = 2000 - maxPodLogBytes int64 = 1 << 20 - podLogTruncatedSuffix = "\n(truncated)" - defaultEventListLimit int64 = 200 - maxEventListLimit int64 = 1000 + defaultPodLogTailLines int64 = 2000 + maxPodLogBytes int64 = 1 << 20 + podLogTruncatedSuffix = "\n(truncated)" + defaultEventListLimit int64 = 200 + maxEventListLimit int64 = 1000 + defaultControlPlanePodListLimit int64 = 100 + maxControlPlanePodListLimit int64 = 500 ) type listControlPlanesInput struct { @@ -53,6 +58,75 @@ type getControlPlaneStatusOutput struct { Conditions []metav1.Condition `json:"conditions"` } +type listControlPlanePodsInput struct { + Namespace string `json:"namespace"` + Name string `json:"name"` + Limit int64 `json:"limit,omitempty"` +} + +type controlPlanePodSummary struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + Phase string `json:"phase"` + NodeName string `json:"nodeName,omitempty"` + ReadyContainers int32 `json:"readyContainers"` + TotalContainers int32 `json:"totalContainers"` + StartTime string `json:"startTime,omitempty"` +} + +type listControlPlanePodsOutput struct { + Items []controlPlanePodSummary `json:"items"` + Truncated bool `json:"truncated,omitempty"` +} + +type getControlPlaneDeploymentStatusInput struct { + Namespace string `json:"namespace"` + Name string `json:"name"` +} + +type deploymentConditionSummary struct { + Type string `json:"type"` + Status string `json:"status"` + Reason string `json:"reason,omitempty"` + Message string `json:"message,omitempty"` + LastUpdateTime string `json:"lastUpdateTime,omitempty"` + LastTransitionTime string `json:"lastTransitionTime,omitempty"` +} + +type getControlPlaneDeploymentStatusOutput struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + Replicas int32 `json:"replicas"` + ReadyReplicas int32 `json:"readyReplicas"` + UpdatedReplicas int32 `json:"updatedReplicas"` + AvailableReplicas int32 `json:"availableReplicas"` + UnavailableReplicas int32 `json:"unavailableReplicas"` + ObservedGeneration int64 `json:"observedGeneration"` + Conditions []deploymentConditionSummary `json:"conditions"` +} + +type getServiceStatusInput struct { + Namespace string `json:"namespace"` + Name string `json:"name"` +} + +type servicePortSummary struct { + Name string `json:"name,omitempty"` + Port int32 `json:"port"` + TargetPort string `json:"targetPort,omitempty"` + NodePort int32 `json:"nodePort,omitempty"` + Protocol string `json:"protocol,omitempty"` +} + +type getServiceStatusOutput struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + Type string `json:"type"` + ClusterIP string `json:"clusterIP,omitempty"` + Ports []servicePortSummary `json:"ports"` + Annotations map[string]string `json:"annotations,omitempty"` +} + type listWorkspacesInput struct { Namespace string `json:"namespace,omitempty"` } @@ -68,6 +142,26 @@ type listWorkspacesOutput struct { Items []workspaceSummary `json:"items"` } +type getWorkspaceInput struct { + Namespace string `json:"namespace"` + Name string `json:"name"` +} + +type getWorkspaceOutput struct { + Workspace workspaceSummary `json:"workspace"` +} + +type setWorkspaceRunningInput struct { + Namespace string `json:"namespace"` + Name string `json:"name"` + Running bool `json:"running"` +} + +type setWorkspaceRunningOutput struct { + Workspace workspaceSummary `json:"workspace"` + Updated bool `json:"updated"` +} + type listTemplatesInput struct { Namespace string `json:"namespace,omitempty"` } @@ -83,6 +177,26 @@ type listTemplatesOutput struct { Items []templateSummary `json:"items"` } +type getTemplateInput struct { + Namespace string `json:"namespace"` + Name string `json:"name"` +} + +type getTemplateOutput struct { + Template templateSummary `json:"template"` +} + +type setTemplateRunningInput struct { + Namespace string `json:"namespace"` + Name string `json:"name"` + Running bool `json:"running"` +} + +type setTemplateRunningOutput struct { + Template templateSummary `json:"template"` + Updated bool `json:"updated"` +} + type getEventsInput struct { Namespace string `json:"namespace"` Name string `json:"name,omitempty"` @@ -183,6 +297,39 @@ func registerTools(server *mcp.Server, k8sClient client.Client, clientset kubern }, nil }) + mcp.AddTool(server, &mcp.Tool{ + Name: "list_control_plane_pods", + Description: "List pods for a CoderControlPlane by namespace and name.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, input listControlPlanePodsInput) (*mcp.CallToolResult, listControlPlanePodsOutput, error) { + output, err := listControlPlanePods(ctx, k8sClient, input) + if err != nil { + return nil, listControlPlanePodsOutput{}, err + } + return nil, output, nil + }) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_control_plane_deployment_status", + Description: "Get Deployment status for a CoderControlPlane by namespace and name.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, input getControlPlaneDeploymentStatusInput) (*mcp.CallToolResult, getControlPlaneDeploymentStatusOutput, error) { + output, err := getControlPlaneDeploymentStatus(ctx, k8sClient, input) + if err != nil { + return nil, getControlPlaneDeploymentStatusOutput{}, err + } + return nil, output, nil + }) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_service_status", + Description: "Get Kubernetes Service status by namespace and name.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, input getServiceStatusInput) (*mcp.CallToolResult, getServiceStatusOutput, error) { + output, err := getServiceStatus(ctx, k8sClient, input) + if err != nil { + return nil, getServiceStatusOutput{}, err + } + return nil, output, nil + }) + mcp.AddTool(server, &mcp.Tool{ Name: "list_workspaces", Description: "List CoderWorkspace resources.", @@ -198,12 +345,29 @@ func registerTools(server *mcp.Server, k8sClient client.Client, clientset kubern output := listWorkspacesOutput{Items: make([]workspaceSummary, 0, len(workspaceList.Items))} for _, workspace := range workspaceList.Items { - output.Items = append(output.Items, workspaceSummary{ - Name: workspace.Name, - Namespace: workspace.Namespace, - Running: workspace.Spec.Running, - AutoShutdown: formatOptionalTime(workspace.Status.AutoShutdown), - }) + output.Items = append(output.Items, workspaceToSummary(&workspace)) + } + return nil, output, nil + }) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_workspace", + Description: "Get a CoderWorkspace resource by namespace and name.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, input getWorkspaceInput) (*mcp.CallToolResult, getWorkspaceOutput, error) { + workspace, err := getWorkspaceDetail(ctx, k8sClient, input) + if err != nil { + return nil, getWorkspaceOutput{}, err + } + return nil, getWorkspaceOutput{Workspace: workspace}, nil + }) + + mcp.AddTool(server, &mcp.Tool{ + Name: "set_workspace_running", + Description: "Set spec.running for a CoderWorkspace resource.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, input setWorkspaceRunningInput) (*mcp.CallToolResult, setWorkspaceRunningOutput, error) { + output, err := setWorkspaceRunning(ctx, k8sClient, input) + if err != nil { + return nil, setWorkspaceRunningOutput{}, err } return nil, output, nil }) @@ -223,12 +387,29 @@ func registerTools(server *mcp.Server, k8sClient client.Client, clientset kubern output := listTemplatesOutput{Items: make([]templateSummary, 0, len(templateList.Items))} for _, template := range templateList.Items { - output.Items = append(output.Items, templateSummary{ - Name: template.Name, - Namespace: template.Namespace, - Running: template.Spec.Running, - AutoShutdown: formatOptionalTime(template.Status.AutoShutdown), - }) + output.Items = append(output.Items, templateToSummary(&template)) + } + return nil, output, nil + }) + + mcp.AddTool(server, &mcp.Tool{ + Name: "get_template", + Description: "Get a CoderTemplate resource by namespace and name.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, input getTemplateInput) (*mcp.CallToolResult, getTemplateOutput, error) { + template, err := getTemplateDetail(ctx, k8sClient, input) + if err != nil { + return nil, getTemplateOutput{}, err + } + return nil, getTemplateOutput{Template: template}, nil + }) + + mcp.AddTool(server, &mcp.Tool{ + Name: "set_template_running", + Description: "Set spec.running for a CoderTemplate resource.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, input setTemplateRunningInput) (*mcp.CallToolResult, setTemplateRunningOutput, error) { + output, err := setTemplateRunning(ctx, k8sClient, input) + if err != nil { + return nil, setTemplateRunningOutput{}, err } return nil, output, nil }) @@ -352,6 +533,385 @@ func registerTools(server *mcp.Server, k8sClient client.Client, clientset kubern }) } +func listControlPlanePods(ctx context.Context, k8sClient client.Client, input listControlPlanePodsInput) (listControlPlanePodsOutput, error) { + if k8sClient == nil { + return listControlPlanePodsOutput{}, fmt.Errorf("assertion failed: Kubernetes client must not be nil") + } + if input.Namespace == "" { + return listControlPlanePodsOutput{}, fmt.Errorf("namespace is required") + } + if input.Name == "" { + return listControlPlanePodsOutput{}, fmt.Errorf("name is required") + } + + labels := controlPlaneWorkloadLabels(input.Name) + if len(labels) == 0 { + return listControlPlanePodsOutput{}, fmt.Errorf("assertion failed: control plane workload labels must not be empty") + } + + podList := &corev1.PodList{} + if err := k8sClient.List(ctx, podList, client.InNamespace(input.Namespace), client.MatchingLabels(labels)); err != nil { + return listControlPlanePodsOutput{}, fmt.Errorf("list control plane pods for %s/%s: %w", input.Namespace, input.Name, err) + } + + sort.Slice(podList.Items, func(i, j int) bool { + if podList.Items[i].Name == podList.Items[j].Name { + return podList.Items[i].Namespace < podList.Items[j].Namespace + } + return podList.Items[i].Name < podList.Items[j].Name + }) + + limit := sanitizeControlPlanePodListLimit(input.Limit) + if limit <= 0 { + return listControlPlanePodsOutput{}, fmt.Errorf("assertion failed: pod list limit must be positive") + } + + output := listControlPlanePodsOutput{Items: make([]controlPlanePodSummary, 0, minInt(len(podList.Items), int(limit)))} + for i, pod := range podList.Items { + if int64(i) >= limit { + output.Truncated = true + break + } + + readyContainers := int32(0) + for _, status := range pod.Status.ContainerStatuses { + if status.Ready { + readyContainers++ + } + } + + totalContainers, err := safeInt32Count(len(pod.Spec.Containers)) + if err != nil { + return listControlPlanePodsOutput{}, err + } + if totalContainers == 0 { + totalContainers, err = safeInt32Count(len(pod.Status.ContainerStatuses)) + if err != nil { + return listControlPlanePodsOutput{}, err + } + } + + output.Items = append(output.Items, controlPlanePodSummary{ + Name: pod.Name, + Namespace: pod.Namespace, + Phase: string(pod.Status.Phase), + NodeName: pod.Spec.NodeName, + ReadyContainers: readyContainers, + TotalContainers: totalContainers, + StartTime: formatOptionalTime(pod.Status.StartTime), + }) + } + + return output, nil +} + +func getControlPlaneDeploymentStatus( + ctx context.Context, + k8sClient client.Client, + input getControlPlaneDeploymentStatusInput, +) (getControlPlaneDeploymentStatusOutput, error) { + if k8sClient == nil { + return getControlPlaneDeploymentStatusOutput{}, fmt.Errorf("assertion failed: Kubernetes client must not be nil") + } + if input.Namespace == "" { + return getControlPlaneDeploymentStatusOutput{}, fmt.Errorf("namespace is required") + } + if input.Name == "" { + return getControlPlaneDeploymentStatusOutput{}, fmt.Errorf("name is required") + } + + deployment := &appsv1.Deployment{} + if err := k8sClient.Get(ctx, client.ObjectKey{Namespace: input.Namespace, Name: input.Name}, deployment); err != nil { + return getControlPlaneDeploymentStatusOutput{}, fmt.Errorf("get deployment %s/%s: %w", input.Namespace, input.Name, err) + } + + conditions := make([]deploymentConditionSummary, 0, len(deployment.Status.Conditions)) + for _, condition := range deployment.Status.Conditions { + conditions = append(conditions, deploymentConditionSummary{ + Type: string(condition.Type), + Status: string(condition.Status), + Reason: condition.Reason, + Message: condition.Message, + LastUpdateTime: formatTime(condition.LastUpdateTime), + LastTransitionTime: formatTime(condition.LastTransitionTime), + }) + } + sort.Slice(conditions, func(i, j int) bool { + return conditions[i].Type < conditions[j].Type + }) + + return getControlPlaneDeploymentStatusOutput{ + Name: deployment.Name, + Namespace: deployment.Namespace, + Replicas: desiredReplicas(deployment.Spec.Replicas), + ReadyReplicas: deployment.Status.ReadyReplicas, + UpdatedReplicas: deployment.Status.UpdatedReplicas, + AvailableReplicas: deployment.Status.AvailableReplicas, + UnavailableReplicas: deployment.Status.UnavailableReplicas, + ObservedGeneration: deployment.Status.ObservedGeneration, + Conditions: conditions, + }, nil +} + +func getServiceStatus(ctx context.Context, k8sClient client.Client, input getServiceStatusInput) (getServiceStatusOutput, error) { + if k8sClient == nil { + return getServiceStatusOutput{}, fmt.Errorf("assertion failed: Kubernetes client must not be nil") + } + if input.Namespace == "" { + return getServiceStatusOutput{}, fmt.Errorf("namespace is required") + } + if input.Name == "" { + return getServiceStatusOutput{}, fmt.Errorf("name is required") + } + + service := &corev1.Service{} + if err := k8sClient.Get(ctx, client.ObjectKey{Namespace: input.Namespace, Name: input.Name}, service); err != nil { + return getServiceStatusOutput{}, fmt.Errorf("get service %s/%s: %w", input.Namespace, input.Name, err) + } + + ports := make([]servicePortSummary, 0, len(service.Spec.Ports)) + for _, port := range service.Spec.Ports { + ports = append(ports, servicePortSummary{ + Name: port.Name, + Port: port.Port, + TargetPort: port.TargetPort.String(), + NodePort: port.NodePort, + Protocol: string(port.Protocol), + }) + } + + return getServiceStatusOutput{ + Name: service.Name, + Namespace: service.Namespace, + Type: string(service.Spec.Type), + ClusterIP: service.Spec.ClusterIP, + Ports: ports, + Annotations: cloneStringMap(service.Annotations), + }, nil +} + +func getWorkspaceDetail(ctx context.Context, k8sClient client.Client, input getWorkspaceInput) (workspaceSummary, error) { + if k8sClient == nil { + return workspaceSummary{}, fmt.Errorf("assertion failed: Kubernetes client must not be nil") + } + if input.Namespace == "" { + return workspaceSummary{}, fmt.Errorf("namespace is required") + } + if input.Name == "" { + return workspaceSummary{}, fmt.Errorf("name is required") + } + + workspace := &aggregationv1alpha1.CoderWorkspace{} + if err := k8sClient.Get(ctx, client.ObjectKey{Namespace: input.Namespace, Name: input.Name}, workspace); err != nil { + return workspaceSummary{}, fmt.Errorf("get CoderWorkspace %s/%s: %w", input.Namespace, input.Name, err) + } + if workspace.Namespace != input.Namespace || workspace.Name != input.Name { + return workspaceSummary{}, fmt.Errorf( + "assertion failed: fetched workspace %s/%s does not match request %s/%s", + workspace.Namespace, + workspace.Name, + input.Namespace, + input.Name, + ) + } + + return workspaceToSummary(workspace), nil +} + +func setWorkspaceRunning( + ctx context.Context, + k8sClient client.Client, + input setWorkspaceRunningInput, +) (setWorkspaceRunningOutput, error) { + if k8sClient == nil { + return setWorkspaceRunningOutput{}, fmt.Errorf("assertion failed: Kubernetes client must not be nil") + } + if input.Namespace == "" { + return setWorkspaceRunningOutput{}, fmt.Errorf("namespace is required") + } + if input.Name == "" { + return setWorkspaceRunningOutput{}, fmt.Errorf("name is required") + } + + workspace := &aggregationv1alpha1.CoderWorkspace{} + if err := k8sClient.Get(ctx, client.ObjectKey{Namespace: input.Namespace, Name: input.Name}, workspace); err != nil { + return setWorkspaceRunningOutput{}, fmt.Errorf("get CoderWorkspace %s/%s: %w", input.Namespace, input.Name, err) + } + if workspace.Namespace != input.Namespace || workspace.Name != input.Name { + return setWorkspaceRunningOutput{}, fmt.Errorf( + "assertion failed: fetched workspace %s/%s does not match request %s/%s", + workspace.Namespace, + workspace.Name, + input.Namespace, + input.Name, + ) + } + + updated := workspace.Spec.Running != input.Running + workspace.Spec.Running = input.Running + if updated { + if err := k8sClient.Update(ctx, workspace); err != nil { + return setWorkspaceRunningOutput{}, fmt.Errorf("update CoderWorkspace %s/%s: %w", input.Namespace, input.Name, err) + } + } + + return setWorkspaceRunningOutput{ + Workspace: workspaceToSummary(workspace), + Updated: updated, + }, nil +} + +func getTemplateDetail(ctx context.Context, k8sClient client.Client, input getTemplateInput) (templateSummary, error) { + if k8sClient == nil { + return templateSummary{}, fmt.Errorf("assertion failed: Kubernetes client must not be nil") + } + if input.Namespace == "" { + return templateSummary{}, fmt.Errorf("namespace is required") + } + if input.Name == "" { + return templateSummary{}, fmt.Errorf("name is required") + } + + template := &aggregationv1alpha1.CoderTemplate{} + if err := k8sClient.Get(ctx, client.ObjectKey{Namespace: input.Namespace, Name: input.Name}, template); err != nil { + return templateSummary{}, fmt.Errorf("get CoderTemplate %s/%s: %w", input.Namespace, input.Name, err) + } + if template.Namespace != input.Namespace || template.Name != input.Name { + return templateSummary{}, fmt.Errorf( + "assertion failed: fetched template %s/%s does not match request %s/%s", + template.Namespace, + template.Name, + input.Namespace, + input.Name, + ) + } + + return templateToSummary(template), nil +} + +func setTemplateRunning( + ctx context.Context, + k8sClient client.Client, + input setTemplateRunningInput, +) (setTemplateRunningOutput, error) { + if k8sClient == nil { + return setTemplateRunningOutput{}, fmt.Errorf("assertion failed: Kubernetes client must not be nil") + } + if input.Namespace == "" { + return setTemplateRunningOutput{}, fmt.Errorf("namespace is required") + } + if input.Name == "" { + return setTemplateRunningOutput{}, fmt.Errorf("name is required") + } + + template := &aggregationv1alpha1.CoderTemplate{} + if err := k8sClient.Get(ctx, client.ObjectKey{Namespace: input.Namespace, Name: input.Name}, template); err != nil { + return setTemplateRunningOutput{}, fmt.Errorf("get CoderTemplate %s/%s: %w", input.Namespace, input.Name, err) + } + if template.Namespace != input.Namespace || template.Name != input.Name { + return setTemplateRunningOutput{}, fmt.Errorf( + "assertion failed: fetched template %s/%s does not match request %s/%s", + template.Namespace, + template.Name, + input.Namespace, + input.Name, + ) + } + + updated := template.Spec.Running != input.Running + template.Spec.Running = input.Running + if updated { + if err := k8sClient.Update(ctx, template); err != nil { + return setTemplateRunningOutput{}, fmt.Errorf("update CoderTemplate %s/%s: %w", input.Namespace, input.Name, err) + } + } + + return setTemplateRunningOutput{ + Template: templateToSummary(template), + Updated: updated, + }, nil +} + +func controlPlaneWorkloadLabels(name string) map[string]string { + return map[string]string{ + "app.kubernetes.io/name": "coder-control-plane", + "app.kubernetes.io/instance": name, + "app.kubernetes.io/managed-by": "coder-k8s", + } +} + +func workspaceToSummary(workspace *aggregationv1alpha1.CoderWorkspace) workspaceSummary { + if workspace == nil { + panic("assertion failed: workspace must not be nil") + } + + return workspaceSummary{ + Name: workspace.Name, + Namespace: workspace.Namespace, + Running: workspace.Spec.Running, + AutoShutdown: formatOptionalTime(workspace.Status.AutoShutdown), + } +} + +func templateToSummary(template *aggregationv1alpha1.CoderTemplate) templateSummary { + if template == nil { + panic("assertion failed: template must not be nil") + } + + return templateSummary{ + Name: template.Name, + Namespace: template.Namespace, + Running: template.Spec.Running, + AutoShutdown: formatOptionalTime(template.Status.AutoShutdown), + } +} + +func desiredReplicas(replicas *int32) int32 { + if replicas == nil { + return 0 + } + return *replicas +} + +func sanitizeControlPlanePodListLimit(limit int64) int64 { + if limit <= 0 { + return defaultControlPlanePodListLimit + } + if limit > maxControlPlanePodListLimit { + return maxControlPlanePodListLimit + } + return limit +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func safeInt32Count(value int) (int32, error) { + if value < 0 { + return 0, fmt.Errorf("assertion failed: count must not be negative") + } + if value > math.MaxInt32 { + return 0, fmt.Errorf("assertion failed: count %d exceeds int32 max %d", value, math.MaxInt32) + } + return int32(value), nil +} + +func cloneStringMap(source map[string]string) map[string]string { + if len(source) == 0 { + return nil + } + + cloned := make(map[string]string, len(source)) + for key, value := range source { + cloned[key] = value + } + return cloned +} + func formatOptionalTime(value *metav1.Time) string { if value == nil || value.IsZero() { return "" @@ -359,6 +919,13 @@ func formatOptionalTime(value *metav1.Time) string { return value.UTC().Format(time.RFC3339) } +func formatTime(value metav1.Time) string { + if value.IsZero() { + return "" + } + return value.UTC().Format(time.RFC3339) +} + func eventTimestamp(event corev1.Event) string { if !event.EventTime.IsZero() { return event.EventTime.Time.UTC().Format(time.RFC3339) diff --git a/internal/app/mcpapp/tools_test.go b/internal/app/mcpapp/tools_test.go new file mode 100644 index 00000000..beb6317f --- /dev/null +++ b/internal/app/mcpapp/tools_test.go @@ -0,0 +1,552 @@ +package mcpapp + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + aggregationv1alpha1 "github.com/coder/coder-k8s/api/aggregation/v1alpha1" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + apiMeta "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func TestListControlPlanePods(t *testing.T) { + t.Helper() + + labelsForAlpha := controlPlaneWorkloadLabels("alpha") + startTime := metav1.NewTime(time.Date(2026, time.January, 15, 12, 0, 0, 0, time.UTC)) + + k8sClient := mustNewFakeClient( + t, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "alpha-1", Namespace: "default", Labels: labelsForAlpha}, + Spec: corev1.PodSpec{ + NodeName: "node-1", + Containers: []corev1.Container{{Name: "main"}, {Name: "sidecar"}}, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodRunning, + StartTime: &startTime, + ContainerStatuses: []corev1.ContainerStatus{{Name: "main", Ready: true}, {Name: "sidecar", Ready: false}}, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "alpha-2", Namespace: "default", Labels: labelsForAlpha}, + Spec: corev1.PodSpec{Containers: []corev1.Container{{Name: "main"}}}, + Status: corev1.PodStatus{ + Phase: corev1.PodPending, + ContainerStatuses: []corev1.ContainerStatus{{Name: "main", Ready: false}}, + }, + }, + &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "other", Namespace: "default", Labels: controlPlaneWorkloadLabels("beta")}, + Spec: corev1.PodSpec{Containers: []corev1.Container{{Name: "main"}}}, + }, + ) + + output, err := listControlPlanePods(context.Background(), k8sClient, listControlPlanePodsInput{ + Namespace: "default", + Name: "alpha", + Limit: 1, + }) + if err != nil { + t.Fatalf("list control plane pods: %v", err) + } + if len(output.Items) != 1 { + t.Fatalf("expected one item due to limit, got %d", len(output.Items)) + } + if !output.Truncated { + t.Fatal("expected truncated output when limit is smaller than matching pods") + } + + pod := output.Items[0] + if pod.Name != "alpha-1" { + t.Fatalf("expected sorted first pod alpha-1, got %q", pod.Name) + } + if pod.ReadyContainers != 1 || pod.TotalContainers != 2 { + t.Fatalf("expected readiness summary 1/2, got %d/%d", pod.ReadyContainers, pod.TotalContainers) + } + if pod.StartTime == "" { + t.Fatal("expected start time to be populated") + } +} + +func TestGetControlPlaneDeploymentStatus(t *testing.T) { + t.Helper() + + replicas := int32(3) + deployment := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{Name: "alpha", Namespace: "default"}, + Spec: appsv1.DeploymentSpec{Replicas: &replicas}, + Status: appsv1.DeploymentStatus{ + ReadyReplicas: 2, + UpdatedReplicas: 3, + AvailableReplicas: 2, + UnavailableReplicas: 1, + ObservedGeneration: 7, + Conditions: []appsv1.DeploymentCondition{ + {Type: appsv1.DeploymentAvailable, Status: corev1.ConditionTrue}, + {Type: appsv1.DeploymentProgressing, Status: corev1.ConditionTrue}, + }, + }, + } + + k8sClient := mustNewFakeClient(t, deployment) + + output, err := getControlPlaneDeploymentStatus(context.Background(), k8sClient, getControlPlaneDeploymentStatusInput{ + Namespace: "default", + Name: "alpha", + }) + if err != nil { + t.Fatalf("get deployment status: %v", err) + } + + if output.Replicas != 3 || output.ReadyReplicas != 2 || output.ObservedGeneration != 7 { + t.Fatalf("unexpected deployment summary: %+v", output) + } + if len(output.Conditions) != 2 { + t.Fatalf("expected two deployment conditions, got %d", len(output.Conditions)) + } + if output.Conditions[0].Type != string(appsv1.DeploymentAvailable) { + t.Fatalf("expected sorted conditions with %q first, got %+v", appsv1.DeploymentAvailable, output.Conditions) + } +} + +func TestGetServiceStatus(t *testing.T) { + t.Helper() + + service := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "alpha", + Namespace: "default", + Annotations: map[string]string{ + "service.beta.kubernetes.io/aws-load-balancer-type": "nlb", + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + ClusterIP: "10.96.0.42", + Ports: []corev1.ServicePort{{ + Name: "http", + Port: 80, + TargetPort: intstr.FromInt(3000), + Protocol: corev1.ProtocolTCP, + }}, + }, + } + + k8sClient := mustNewFakeClient(t, service) + + output, err := getServiceStatus(context.Background(), k8sClient, getServiceStatusInput{ + Namespace: "default", + Name: "alpha", + }) + if err != nil { + t.Fatalf("get service status: %v", err) + } + + if output.Type != string(corev1.ServiceTypeLoadBalancer) { + t.Fatalf("expected service type load balancer, got %q", output.Type) + } + if len(output.Ports) != 1 { + t.Fatalf("expected one service port, got %d", len(output.Ports)) + } + if output.Ports[0].TargetPort != "3000" { + t.Fatalf("expected target port 3000, got %q", output.Ports[0].TargetPort) + } + if output.Annotations["service.beta.kubernetes.io/aws-load-balancer-type"] != "nlb" { + t.Fatalf("expected annotation copy, got %+v", output.Annotations) + } +} + +func TestWorkspaceToolHelpers(t *testing.T) { + t.Helper() + + deadline := metav1.NewTime(time.Date(2026, time.March, 1, 8, 0, 0, 0, time.UTC)) + workspace := &aggregationv1alpha1.CoderWorkspace{ + ObjectMeta: metav1.ObjectMeta{Name: "dev", Namespace: "default"}, + Spec: aggregationv1alpha1.CoderWorkspaceSpec{Running: false}, + Status: aggregationv1alpha1.CoderWorkspaceStatus{AutoShutdown: &deadline}, + } + + k8sClient := mustNewFakeClient(t, workspace) + ctx := context.Background() + + detail, err := getWorkspaceDetail(ctx, k8sClient, getWorkspaceInput{Namespace: "default", Name: "dev"}) + if err != nil { + t.Fatalf("get workspace detail: %v", err) + } + if detail.Running { + t.Fatalf("expected workspace to be stopped, got %+v", detail) + } + + setOutput, err := setWorkspaceRunning(ctx, k8sClient, setWorkspaceRunningInput{Namespace: "default", Name: "dev", Running: true}) + if err != nil { + t.Fatalf("set workspace running: %v", err) + } + if !setOutput.Updated || !setOutput.Workspace.Running { + t.Fatalf("expected workspace update to running=true, got %+v", setOutput) + } + + persisted := &aggregationv1alpha1.CoderWorkspace{} + if err := k8sClient.Get(ctx, types.NamespacedName{Namespace: "default", Name: "dev"}, persisted); err != nil { + t.Fatalf("get persisted workspace: %v", err) + } + if !persisted.Spec.Running { + t.Fatalf("expected persisted running=true, got %+v", persisted.Spec) + } + + noOpOutput, err := setWorkspaceRunning(ctx, k8sClient, setWorkspaceRunningInput{Namespace: "default", Name: "dev", Running: true}) + if err != nil { + t.Fatalf("set workspace running no-op: %v", err) + } + if noOpOutput.Updated { + t.Fatalf("expected no-op update to report Updated=false, got %+v", noOpOutput) + } +} + +func TestTemplateToolHelpers(t *testing.T) { + t.Helper() + + template := &aggregationv1alpha1.CoderTemplate{ + ObjectMeta: metav1.ObjectMeta{Name: "starter", Namespace: "default"}, + Spec: aggregationv1alpha1.CoderTemplateSpec{Running: true}, + } + + k8sClient := mustNewFakeClient(t, template) + ctx := context.Background() + + detail, err := getTemplateDetail(ctx, k8sClient, getTemplateInput{Namespace: "default", Name: "starter"}) + if err != nil { + t.Fatalf("get template detail: %v", err) + } + if !detail.Running { + t.Fatalf("expected template to be running, got %+v", detail) + } + + setOutput, err := setTemplateRunning(ctx, k8sClient, setTemplateRunningInput{Namespace: "default", Name: "starter", Running: false}) + if err != nil { + t.Fatalf("set template running: %v", err) + } + if !setOutput.Updated || setOutput.Template.Running { + t.Fatalf("expected template update to running=false, got %+v", setOutput) + } + + persisted := &aggregationv1alpha1.CoderTemplate{} + if err := k8sClient.Get(ctx, types.NamespacedName{Namespace: "default", Name: "starter"}, persisted); err != nil { + t.Fatalf("get persisted template: %v", err) + } + if persisted.Spec.Running { + t.Fatalf("expected persisted running=false, got %+v", persisted.Spec) + } +} + +func TestListControlPlanePodsInputValidation(t *testing.T) { + t.Helper() + + k8sClient := mustNewFakeClient(t) + + _, err := listControlPlanePods(context.Background(), k8sClient, listControlPlanePodsInput{Name: "alpha"}) + if err == nil { + t.Fatal("expected error when namespace is missing") + } + + _, err = listControlPlanePods(context.Background(), k8sClient, listControlPlanePodsInput{Namespace: "default"}) + if err == nil { + t.Fatal("expected error when name is missing") + } +} + +func mustNewFakeClient(t *testing.T, objects ...client.Object) client.Client { + t.Helper() + + scheme := newScheme() + if scheme == nil { + t.Fatal("expected non-nil scheme") + } + + stub := &stubClient{ + scheme: scheme, + pods: map[types.NamespacedName]*corev1.Pod{}, + deployments: map[types.NamespacedName]*appsv1.Deployment{}, + services: map[types.NamespacedName]*corev1.Service{}, + workspaces: map[types.NamespacedName]*aggregationv1alpha1.CoderWorkspace{}, + templates: map[types.NamespacedName]*aggregationv1alpha1.CoderTemplate{}, + } + + for _, object := range objects { + if object == nil { + continue + } + key := types.NamespacedName{Namespace: object.GetNamespace(), Name: object.GetName()} + switch typed := object.(type) { + case *corev1.Pod: + stub.pods[key] = typed.DeepCopy() + case *appsv1.Deployment: + stub.deployments[key] = typed.DeepCopy() + case *corev1.Service: + stub.services[key] = typed.DeepCopy() + case *aggregationv1alpha1.CoderWorkspace: + stub.workspaces[key] = typed.DeepCopy() + case *aggregationv1alpha1.CoderTemplate: + stub.templates[key] = typed.DeepCopy() + default: + t.Fatalf("unsupported object type for stub client: %T", object) + } + } + + return stub +} + +type stubClient struct { + scheme *runtime.Scheme + pods map[types.NamespacedName]*corev1.Pod + deployments map[types.NamespacedName]*appsv1.Deployment + services map[types.NamespacedName]*corev1.Service + workspaces map[types.NamespacedName]*aggregationv1alpha1.CoderWorkspace + templates map[types.NamespacedName]*aggregationv1alpha1.CoderTemplate +} + +func (s *stubClient) Get(_ context.Context, key client.ObjectKey, obj client.Object, _ ...client.GetOption) error { + if obj == nil { + return fmt.Errorf("assertion failed: object must not be nil") + } + + namespacedName := types.NamespacedName(key) + switch typed := obj.(type) { + case *corev1.Pod: + stored, ok := s.pods[namespacedName] + if !ok { + return newCoreNotFound("pods", key.Name) + } + *typed = *stored.DeepCopy() + return nil + case *appsv1.Deployment: + stored, ok := s.deployments[namespacedName] + if !ok { + return newAppsNotFound("deployments", key.Name) + } + *typed = *stored.DeepCopy() + return nil + case *corev1.Service: + stored, ok := s.services[namespacedName] + if !ok { + return newCoreNotFound("services", key.Name) + } + *typed = *stored.DeepCopy() + return nil + case *aggregationv1alpha1.CoderWorkspace: + stored, ok := s.workspaces[namespacedName] + if !ok { + return apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), key.Name) + } + *typed = *stored.DeepCopy() + return nil + case *aggregationv1alpha1.CoderTemplate: + stored, ok := s.templates[namespacedName] + if !ok { + return apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), key.Name) + } + *typed = *stored.DeepCopy() + return nil + default: + return fmt.Errorf("assertion failed: unsupported object type %T", obj) + } +} + +func (s *stubClient) List(_ context.Context, list client.ObjectList, opts ...client.ListOption) error { + if list == nil { + return fmt.Errorf("assertion failed: list must not be nil") + } + + listOptions := (&client.ListOptions{}).ApplyOptions(opts) + selector := listOptions.LabelSelector + + switch typed := list.(type) { + case *corev1.PodList: + items := make([]corev1.Pod, 0, len(s.pods)) + for _, pod := range s.pods { + if listOptions.Namespace != "" && pod.Namespace != listOptions.Namespace { + continue + } + if selector != nil && !selector.Matches(labels.Set(pod.Labels)) { + continue + } + items = append(items, *pod.DeepCopy()) + } + sort.Slice(items, func(i, j int) bool { + if items[i].Name == items[j].Name { + return items[i].Namespace < items[j].Namespace + } + return items[i].Name < items[j].Name + }) + typed.Items = items + return nil + case *aggregationv1alpha1.CoderWorkspaceList: + items := make([]aggregationv1alpha1.CoderWorkspace, 0, len(s.workspaces)) + for _, workspace := range s.workspaces { + if listOptions.Namespace != "" && workspace.Namespace != listOptions.Namespace { + continue + } + items = append(items, *workspace.DeepCopy()) + } + sort.Slice(items, func(i, j int) bool { + if items[i].Name == items[j].Name { + return items[i].Namespace < items[j].Namespace + } + return items[i].Name < items[j].Name + }) + typed.Items = items + return nil + case *aggregationv1alpha1.CoderTemplateList: + items := make([]aggregationv1alpha1.CoderTemplate, 0, len(s.templates)) + for _, template := range s.templates { + if listOptions.Namespace != "" && template.Namespace != listOptions.Namespace { + continue + } + items = append(items, *template.DeepCopy()) + } + sort.Slice(items, func(i, j int) bool { + if items[i].Name == items[j].Name { + return items[i].Namespace < items[j].Namespace + } + return items[i].Name < items[j].Name + }) + typed.Items = items + return nil + default: + return fmt.Errorf("assertion failed: unsupported list type %T", list) + } +} + +func (s *stubClient) Apply(_ context.Context, _ runtime.ApplyConfiguration, _ ...client.ApplyOption) error { + return fmt.Errorf("assertion failed: Apply is not implemented in stub client") +} + +func (s *stubClient) Create(_ context.Context, _ client.Object, _ ...client.CreateOption) error { + return fmt.Errorf("assertion failed: Create is not implemented in stub client") +} + +func (s *stubClient) Delete(_ context.Context, _ client.Object, _ ...client.DeleteOption) error { + return fmt.Errorf("assertion failed: Delete is not implemented in stub client") +} + +func (s *stubClient) Update(_ context.Context, obj client.Object, _ ...client.UpdateOption) error { + if obj == nil { + return fmt.Errorf("assertion failed: object must not be nil") + } + + key := types.NamespacedName{Namespace: obj.GetNamespace(), Name: obj.GetName()} + switch typed := obj.(type) { + case *aggregationv1alpha1.CoderWorkspace: + if _, exists := s.workspaces[key]; !exists { + return apierrors.NewNotFound(aggregationv1alpha1.Resource("coderworkspaces"), obj.GetName()) + } + s.workspaces[key] = typed.DeepCopy() + return nil + case *aggregationv1alpha1.CoderTemplate: + if _, exists := s.templates[key]; !exists { + return apierrors.NewNotFound(aggregationv1alpha1.Resource("codertemplates"), obj.GetName()) + } + s.templates[key] = typed.DeepCopy() + return nil + default: + return fmt.Errorf("assertion failed: unsupported update type %T", obj) + } +} + +func (s *stubClient) Patch(_ context.Context, _ client.Object, _ client.Patch, _ ...client.PatchOption) error { + return fmt.Errorf("assertion failed: Patch is not implemented in stub client") +} + +func (s *stubClient) DeleteAllOf(_ context.Context, _ client.Object, _ ...client.DeleteAllOfOption) error { + return fmt.Errorf("assertion failed: DeleteAllOf is not implemented in stub client") +} + +func (s *stubClient) Status() client.SubResourceWriter { + return stubSubResourceClient{} +} + +func (s *stubClient) SubResource(_ string) client.SubResourceClient { + return stubSubResourceClient{} +} + +func (s *stubClient) Scheme() *runtime.Scheme { + return s.scheme +} + +func (s *stubClient) RESTMapper() apiMeta.RESTMapper { + return nil +} + +func (s *stubClient) GroupVersionKindFor(_ runtime.Object) (schema.GroupVersionKind, error) { + return schema.GroupVersionKind{}, nil +} + +func (s *stubClient) IsObjectNamespaced(_ runtime.Object) (bool, error) { + return true, nil +} + +type stubSubResourceClient struct{} + +func (stubSubResourceClient) Get( + _ context.Context, + _ client.Object, + _ client.Object, + _ ...client.SubResourceGetOption, +) error { + return fmt.Errorf("assertion failed: subresource Get is not implemented in stub client") +} + +func (stubSubResourceClient) Create( + _ context.Context, + _ client.Object, + _ client.Object, + _ ...client.SubResourceCreateOption, +) error { + return fmt.Errorf("assertion failed: subresource Create is not implemented in stub client") +} + +func (stubSubResourceClient) Update( + _ context.Context, + _ client.Object, + _ ...client.SubResourceUpdateOption, +) error { + return fmt.Errorf("assertion failed: subresource Update is not implemented in stub client") +} + +func (stubSubResourceClient) Patch( + _ context.Context, + _ client.Object, + _ client.Patch, + _ ...client.SubResourcePatchOption, +) error { + return fmt.Errorf("assertion failed: subresource Patch is not implemented in stub client") +} + +func (stubSubResourceClient) Apply( + _ context.Context, + _ runtime.ApplyConfiguration, + _ ...client.SubResourceApplyOption, +) error { + return fmt.Errorf("assertion failed: subresource Apply is not implemented in stub client") +} + +func newCoreNotFound(resource, name string) error { + return apierrors.NewNotFound(schema.GroupResource{Group: "", Resource: resource}, name) +} + +func newAppsNotFound(resource, name string) error { + return apierrors.NewNotFound(schema.GroupResource{Group: "apps", Resource: resource}, name) +}