diff --git a/apis/druid/v1alpha1/druid_types.go b/apis/druid/v1alpha1/druid_types.go index 34c68a1..3413280 100644 --- a/apis/druid/v1alpha1/druid_types.go +++ b/apis/druid/v1alpha1/druid_types.go @@ -93,6 +93,22 @@ type AdditionalContainer struct { EnvFrom []v1.EnvFromSource `json:"envFrom,omitempty"` } +// MiddleManagerDrainStrategy configures operator-managed draining before a +// MiddleManager StatefulSet pod is rolled to a new revision. +type MiddleManagerDrainStrategy struct { + // DrainTimeout is the maximum time to wait for streaming ingestion tasks to + // drain before allowing Kubernetes to replace the MiddleManager pod. + // +optional + // +kubebuilder:default:="1h" + DrainTimeout metav1.Duration `json:"drainTimeout,omitempty"` + + // PodReadyTimeout is the maximum time to wait for Kubernetes to replace the + // pod and for the replacement to become ready on the target StatefulSet revision. + // +optional + // +kubebuilder:default:="30m" + PodReadyTimeout metav1.Duration `json:"podReadyTimeout,omitempty"` +} + // DruidSpec defines the desired state of the Druid cluster. type DruidSpec struct { @@ -270,6 +286,12 @@ type DruidSpec struct { // +kubebuilder:default:=true RollingDeploy bool `json:"rollingDeploy"` + // MiddleManagerDrainStrategy enables operator-managed draining before + // MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the + // standard StatefulSet rolling update behavior. + // +optional + MiddleManagerDrainStrategy *MiddleManagerDrainStrategy `json:"middleManagerDrainStrategy,omitempty"` + // DefaultProbes If set to true this will add default probes (liveness / readiness / startup) for all druid components // but it won't override existing probes // +optional @@ -570,20 +592,32 @@ type DruidNodeTypeStatus struct { Reason string `json:"reason,omitempty"` } +// MiddleManagerDrainStatus reports an in-progress MiddleManager drain rollout. +type MiddleManagerDrainStatus struct { + StatefulSet string `json:"statefulSet,omitempty"` + Phase string `json:"phase,omitempty"` + PodName string `json:"podName,omitempty"` + PodOrdinal int32 `json:"podOrdinal,omitempty"` + OldPodUID string `json:"oldPodUID,omitempty"` + LastTransitionTime metav1.Time `json:"lastTransitionTime,omitempty"` + Message string `json:"message,omitempty"` +} + // DruidClusterStatus Defines the observed state of Druid. type DruidClusterStatus struct { // INSERT ADDITIONAL STATUS FIELD - define observed state of cluster // Important: Run "make" to regenerate code after modifying this file - DruidNodeStatus DruidNodeTypeStatus `json:"druidNodeStatus,omitempty"` - StatefulSets []string `json:"statefulSets,omitempty"` - Deployments []string `json:"deployments,omitempty"` - Services []string `json:"services,omitempty"` - ConfigMaps []string `json:"configMaps,omitempty"` - PodDisruptionBudgets []string `json:"podDisruptionBudgets,omitempty"` - Ingress []string `json:"ingress,omitempty"` - HPAutoScalers []string `json:"hpAutoscalers,omitempty"` - Pods []string `json:"pods,omitempty"` - PersistentVolumeClaims []string `json:"persistentVolumeClaims,omitempty"` + DruidNodeStatus DruidNodeTypeStatus `json:"druidNodeStatus,omitempty"` + StatefulSets []string `json:"statefulSets,omitempty"` + Deployments []string `json:"deployments,omitempty"` + Services []string `json:"services,omitempty"` + ConfigMaps []string `json:"configMaps,omitempty"` + PodDisruptionBudgets []string `json:"podDisruptionBudgets,omitempty"` + Ingress []string `json:"ingress,omitempty"` + HPAutoScalers []string `json:"hpAutoscalers,omitempty"` + Pods []string `json:"pods,omitempty"` + PersistentVolumeClaims []string `json:"persistentVolumeClaims,omitempty"` + MiddleManagerDrain *MiddleManagerDrainStatus `json:"middleManagerDrain,omitempty"` } // Druid is the Schema for the druids API. diff --git a/apis/druid/v1alpha1/zz_generated.deepcopy.go b/apis/druid/v1alpha1/zz_generated.deepcopy.go index 93eff59..bbfd644 100644 --- a/apis/druid/v1alpha1/zz_generated.deepcopy.go +++ b/apis/druid/v1alpha1/zz_generated.deepcopy.go @@ -181,6 +181,11 @@ func (in *DruidClusterStatus) DeepCopyInto(out *DruidClusterStatus) { *out = make([]string, len(*in)) copy(*out, *in) } + if in.MiddleManagerDrain != nil { + in, out := &in.MiddleManagerDrain, &out.MiddleManagerDrain + *out = new(MiddleManagerDrainStatus) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DruidClusterStatus. @@ -691,6 +696,11 @@ func (in *DruidSpec) DeepCopyInto(out *DruidSpec) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.MiddleManagerDrainStrategy != nil { + in, out := &in.MiddleManagerDrainStrategy, &out.MiddleManagerDrainStrategy + *out = new(MiddleManagerDrainStrategy) + **out = **in + } if in.Zookeeper != nil { in, out := &in.Zookeeper, &out.Zookeeper *out = new(ZookeeperSpec) @@ -769,6 +779,39 @@ func (in *MetadataStoreSpec) DeepCopy() *MetadataStoreSpec { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MiddleManagerDrainStatus) DeepCopyInto(out *MiddleManagerDrainStatus) { + *out = *in + in.LastTransitionTime.DeepCopyInto(&out.LastTransitionTime) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MiddleManagerDrainStatus. +func (in *MiddleManagerDrainStatus) DeepCopy() *MiddleManagerDrainStatus { + if in == nil { + return nil + } + out := new(MiddleManagerDrainStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MiddleManagerDrainStrategy) DeepCopyInto(out *MiddleManagerDrainStrategy) { + *out = *in + out.DrainTimeout = in.DrainTimeout + out.PodReadyTimeout = in.PodReadyTimeout +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MiddleManagerDrainStrategy. +func (in *MiddleManagerDrainStrategy) DeepCopy() *MiddleManagerDrainStrategy { + if in == nil { + return nil + } + out := new(MiddleManagerDrainStrategy) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ZookeeperSpec) DeepCopyInto(out *ZookeeperSpec) { *out = *in diff --git a/chart/crds/druid.apache.org_druids.yaml b/chart/crds/druid.apache.org_druids.yaml index e2971b3..be95dcc 100644 --- a/chart/crds/druid.apache.org_druids.yaml +++ b/chart/crds/druid.apache.org_druids.yaml @@ -1980,6 +1980,25 @@ spec: stastd documentation is described in the following documentation: https://druid.apache.org/docs/latest/development/extensions-contrib/statsd.html type: string + middleManagerDrainStrategy: + description: |- + MiddleManagerDrainStrategy enables operator-managed draining before + MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the + standard StatefulSet rolling update behavior. + properties: + drainTimeout: + default: 1h + description: |- + DrainTimeout is the maximum time to wait for streaming ingestion tasks to + drain before allowing Kubernetes to replace the MiddleManager pod. + type: string + podReadyTimeout: + default: 30m + description: |- + PodReadyTimeout is the maximum time to wait for Kubernetes to replace the + pod and for the replacement to become ready on the target StatefulSet revision. + type: string + type: object nodeSelector: additionalProperties: type: string @@ -11877,6 +11896,27 @@ spec: items: type: string type: array + middleManagerDrain: + description: MiddleManagerDrainStatus reports an in-progress MiddleManager + drain rollout. + properties: + lastTransitionTime: + format: date-time + type: string + message: + type: string + oldPodUID: + type: string + phase: + type: string + podName: + type: string + podOrdinal: + format: int32 + type: integer + statefulSet: + type: string + type: object persistentVolumeClaims: items: type: string diff --git a/config/crd/bases/druid.apache.org_druids.yaml b/config/crd/bases/druid.apache.org_druids.yaml index e2971b3..be95dcc 100644 --- a/config/crd/bases/druid.apache.org_druids.yaml +++ b/config/crd/bases/druid.apache.org_druids.yaml @@ -1980,6 +1980,25 @@ spec: stastd documentation is described in the following documentation: https://druid.apache.org/docs/latest/development/extensions-contrib/statsd.html type: string + middleManagerDrainStrategy: + description: |- + MiddleManagerDrainStrategy enables operator-managed draining before + MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the + standard StatefulSet rolling update behavior. + properties: + drainTimeout: + default: 1h + description: |- + DrainTimeout is the maximum time to wait for streaming ingestion tasks to + drain before allowing Kubernetes to replace the MiddleManager pod. + type: string + podReadyTimeout: + default: 30m + description: |- + PodReadyTimeout is the maximum time to wait for Kubernetes to replace the + pod and for the replacement to become ready on the target StatefulSet revision. + type: string + type: object nodeSelector: additionalProperties: type: string @@ -11877,6 +11896,27 @@ spec: items: type: string type: array + middleManagerDrain: + description: MiddleManagerDrainStatus reports an in-progress MiddleManager + drain rollout. + properties: + lastTransitionTime: + format: date-time + type: string + message: + type: string + oldPodUID: + type: string + phase: + type: string + podName: + type: string + podOrdinal: + format: int32 + type: integer + statefulSet: + type: string + type: object persistentVolumeClaims: items: type: string diff --git a/controllers/druid/handler.go b/controllers/druid/handler.go index b236ef4..b4499cc 100644 --- a/controllers/druid/handler.go +++ b/controllers/druid/handler.go @@ -147,6 +147,12 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui nodeSpec.Ports = append(nodeSpec.Ports, v1.ContainerPort{ContainerPort: nodeSpec.DruidPort, Name: "druid-port"}) + if m.Spec.MiddleManagerDrainStrategy != nil && nodeSpec.NodeType == middleManager && nodeSpec.Kind == "Deployment" { + logger.Info("MiddleManager drain strategy is only supported for StatefulSet workloads; using standard Deployment rollout", + "nodeSpecUniqueStr", nodeSpecUniqueStr, + "namespace", m.Namespace) + } + if nodeSpec.Kind == "Deployment" { if deployCreateUpdateStatus, err := sdkCreateOrUpdateAsNeeded(ctx, sdk, func() (object, error) { @@ -183,13 +189,22 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui } } + if m.Generation > 1 && nodeSpec.NodeType == middleManager && m.Spec.MiddleManagerDrainStrategy == nil { + cleanupStaleMiddleManagerDrainState(ctx, sdk, m, nodeSpecUniqueStr, emitEvents) + } + + stsUpdaterFn := noopUpdaterFn + if m.Generation > 1 && nodeSpec.NodeType == middleManager && m.Spec.MiddleManagerDrainStrategy != nil { + stsUpdaterFn = middleManagerDrainStatefulSetUpdaterFn + } + // Create/Update StatefulSet if stsCreateUpdateStatus, err := sdkCreateOrUpdateAsNeeded(ctx, sdk, func() (object, error) { return makeStatefulSet(&nodeSpec, m, lm, nodeSpecUniqueStr, fmt.Sprintf("%s-%s", commonConfigSHA, nodeConfigSHA), firstServiceName) }, func() object { return &appsv1.StatefulSet{} }, - statefulSetIsEquals, noopUpdaterFn, m, statefulSetNames, emitEvents); err != nil { + statefulSetIsEquals, stsUpdaterFn, m, statefulSetNames, emitEvents); err != nil { return err } else if m.Spec.RollingDeploy { @@ -208,6 +223,16 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui //Check StatefulSet rolling update status, if in-progress then stop here done, err := isObjFullyDeployed(ctx, sdk, nodeSpec, nodeSpecUniqueStr, m, func() object { return &appsv1.StatefulSet{} }, emitEvents) if !done { + if nodeSpec.NodeType == middleManager && m.Spec.MiddleManagerDrainStrategy != nil { + if err := processMiddleManagerRollingRestart(ctx, sdk, m, nodeSpecUniqueStr, m.Spec.MiddleManagerDrainStrategy, emitEvents); err != nil { + return err + } + // The drain state machine owns the StatefulSet partition while a + // MiddleManager rollout is in progress. Return immediately so + // sdkCreateOrUpdateAsNeeded is not re-entered in this reconcile + // and cannot overwrite the partition selected for the current phase. + return nil + } return err } } @@ -273,6 +298,7 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui //update status and delete unwanted resources updatedStatus := v1alpha1.DruidClusterStatus{} + updatedStatus.MiddleManagerDrain = m.Status.MiddleManagerDrain updatedStatus.StatefulSets = deleteUnusedResources(ctx, sdk, m, statefulSetNames, ls, func() objectList { return &appsv1.StatefulSetList{} }, diff --git a/controllers/druid/middle_manager_drain.go b/controllers/druid/middle_manager_drain.go new file mode 100644 index 0000000..712cba5 --- /dev/null +++ b/controllers/druid/middle_manager_drain.go @@ -0,0 +1,697 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ +package druid + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/apache/druid-operator/apis/druid/v1alpha1" + druidapi "github.com/apache/druid-operator/pkg/druidapi" + internalhttp "github.com/apache/druid-operator/pkg/http" + appsv1 "k8s.io/api/apps/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +const ( + middleManagerDrainPhaseDraining = "Draining" + middleManagerDrainPhaseWaitingForPod = "WaitingForPod" + middleManagerDrainPhaseBlocked = "Blocked" + + defaultMiddleManagerDrainTimeout = time.Hour + defaultMiddleManagerPodReadyTimeout = 30 * time.Minute + statefulSetPartitionResetValue = int32(0) +) + +var ( + middleManagerDrainStates = sync.Map{} + workerHostnamePattern = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`) +) + +type middleManagerDrainState struct { + Phase string + PodName string + PodOrdinal int32 + OldPodUID string + LastUpdateTime time.Time +} + +type middleManagerDrainConfig struct { + DrainTimeout time.Duration + PodReadyTimeout time.Duration +} + +type middleManagerDruidAPI interface { + DisableWorker(workerHost string) error + EnableWorker(workerHost string) error + GetTaskPayload(taskID string) (*taskPayloadResponse, error) + TriggerTaskGroupHandoff(supervisorID string, taskGroupIDs []int) error + ExecuteSQL(query string) ([]byte, error) +} + +type middleManagerDruidHTTPAPI struct { + baseURL string + httpClient internalhttp.DruidHTTP +} + +type druidSQLRequest struct { + Query string `json:"query"` +} + +type runningTaskInfo struct { + TaskID string `json:"task_id"` + DataSource string `json:"datasource"` + Type string `json:"type"` +} + +type taskPayloadResponse struct { + Task string `json:"task"` + Payload struct { + DataSource string `json:"dataSource"` + IOConfig struct { + TaskGroupID *int `json:"taskGroupId"` + } `json:"ioConfig"` + } `json:"payload"` +} + +type taskGroupHandoffRequest struct { + TaskGroupIDs []int `json:"taskGroupIds"` +} + +func newMiddleManagerDruidAPI(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid) (middleManagerDruidAPI, error) { + routerURL, err := druidapi.GetRouterSvcUrl(drd.Namespace, drd.Name, sdk) + if err != nil { + return nil, fmt.Errorf("failed to discover Druid router service: %w", err) + } + + basicAuth, err := druidapi.GetAuthCreds(ctx, sdk, drd.Spec.Auth) + if err != nil { + return nil, fmt.Errorf("failed to get Druid API credentials: %w", err) + } + + return &middleManagerDruidHTTPAPI{ + baseURL: routerURL, + httpClient: internalhttp.NewHTTPClient( + &http.Client{}, + &internalhttp.Auth{BasicAuth: basicAuth}, + ), + }, nil +} + +func (c *middleManagerDruidHTTPAPI) DisableWorker(workerHost string) error { + path := fmt.Sprintf("%s/%s/disable", druidapi.MakePath(c.baseURL, "indexer", "worker"), url.PathEscape(workerHost)) + resp, err := c.httpClient.Do(http.MethodPost, path, nil) + if err != nil { + return fmt.Errorf("failed to call disable API for worker %q: %w", workerHost, err) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("disable API returned status %d for worker %q: %s", resp.StatusCode, workerHost, resp.ResponseBody) + } + return nil +} + +func (c *middleManagerDruidHTTPAPI) EnableWorker(workerHost string) error { + path := fmt.Sprintf("%s/%s/enable", druidapi.MakePath(c.baseURL, "indexer", "worker"), url.PathEscape(workerHost)) + resp, err := c.httpClient.Do(http.MethodPost, path, nil) + if err != nil { + return fmt.Errorf("failed to call enable API for worker %q: %w", workerHost, err) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("enable API returned status %d for worker %q: %s", resp.StatusCode, workerHost, resp.ResponseBody) + } + return nil +} + +func (c *middleManagerDruidHTTPAPI) GetTaskPayload(taskID string) (*taskPayloadResponse, error) { + path := fmt.Sprintf("%s/%s", druidapi.MakePath(c.baseURL, "indexer", "task"), url.PathEscape(taskID)) + resp, err := c.httpClient.Do(http.MethodGet, path, nil) + if err != nil { + return nil, fmt.Errorf("failed to fetch task payload for %q: %w", taskID, err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("task payload API returned status %d for %q: %s", resp.StatusCode, taskID, resp.ResponseBody) + } + + var payload taskPayloadResponse + if err := json.Unmarshal([]byte(resp.ResponseBody), &payload); err != nil { + return nil, fmt.Errorf("failed to decode task payload for %q: %w", taskID, err) + } + return &payload, nil +} + +func (c *middleManagerDruidHTTPAPI) TriggerTaskGroupHandoff(supervisorID string, taskGroupIDs []int) error { + path := fmt.Sprintf("%s/%s/taskGroups/handoff", druidapi.MakePath(c.baseURL, "indexer", "supervisor"), url.PathEscape(supervisorID)) + reqBody, err := json.Marshal(taskGroupHandoffRequest{TaskGroupIDs: taskGroupIDs}) + if err != nil { + return fmt.Errorf("failed to marshal handoff request for %q: %w", supervisorID, err) + } + + resp, err := c.httpClient.Do(http.MethodPost, path, reqBody) + if err != nil { + return fmt.Errorf("failed to trigger handoff for supervisor %q: %w", supervisorID, err) + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("handoff API returned status %d for supervisor %q: %s", resp.StatusCode, supervisorID, resp.ResponseBody) + } + return nil +} + +func (c *middleManagerDruidHTTPAPI) ExecuteSQL(query string) ([]byte, error) { + reqBody, err := json.Marshal(druidSQLRequest{Query: query}) + if err != nil { + return nil, fmt.Errorf("failed to marshal SQL request: %w", err) + } + + resp, err := c.httpClient.Do(http.MethodPost, druidapi.MakeSQLPath(c.baseURL), reqBody) + if err != nil { + return nil, fmt.Errorf("failed to execute Druid SQL: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Druid SQL API returned status %d: %s", resp.StatusCode, resp.ResponseBody) + } + return []byte(resp.ResponseBody), nil +} + +func normalizeMiddleManagerDrainConfig(strategy *v1alpha1.MiddleManagerDrainStrategy) middleManagerDrainConfig { + config := middleManagerDrainConfig{ + DrainTimeout: defaultMiddleManagerDrainTimeout, + PodReadyTimeout: defaultMiddleManagerPodReadyTimeout, + } + if strategy == nil { + return config + } + if strategy.DrainTimeout.Duration > 0 { + config.DrainTimeout = strategy.DrainTimeout.Duration + } + if strategy.PodReadyTimeout.Duration > 0 { + config.PodReadyTimeout = strategy.PodReadyTimeout.Duration + } + return config +} + +func middleManagerDrainStateKey(namespace, druidCR, statefulSetName string) string { + return fmt.Sprintf("%s/%s/%s", namespace, druidCR, statefulSetName) +} + +func getMiddleManagerDrainState(namespace, druidCR, statefulSetName string) (*middleManagerDrainState, bool) { + key := middleManagerDrainStateKey(namespace, druidCR, statefulSetName) + if value, exists := middleManagerDrainStates.Load(key); exists { + return value.(*middleManagerDrainState), true + } + return nil, false +} + +func loadMiddleManagerDrainState(drd *v1alpha1.Druid, statefulSetName string) (*middleManagerDrainState, bool) { + if state, exists := getMiddleManagerDrainState(drd.Namespace, drd.Name, statefulSetName); exists { + return state, true + } + + status := drd.Status.MiddleManagerDrain + if status == nil || status.StatefulSet != statefulSetName || status.Phase == "" || status.PodName == "" { + return nil, false + } + + state := &middleManagerDrainState{ + Phase: status.Phase, + PodName: status.PodName, + PodOrdinal: status.PodOrdinal, + OldPodUID: status.OldPodUID, + LastUpdateTime: status.LastTransitionTime.Time, + } + if state.LastUpdateTime.IsZero() { + state.LastUpdateTime = time.Now() + } + middleManagerDrainStates.Store(middleManagerDrainStateKey(drd.Namespace, drd.Name, statefulSetName), state) + return state, true +} + +func setMiddleManagerDrainState(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, state *middleManagerDrainState, message string, emitEvent EventEmitter) error { + if state.LastUpdateTime.IsZero() { + state.LastUpdateTime = time.Now() + } + middleManagerDrainStates.Store(middleManagerDrainStateKey(drd.Namespace, drd.Name, statefulSetName), state) + + return patchMiddleManagerDrainStatus(ctx, sdk, drd, &v1alpha1.MiddleManagerDrainStatus{ + StatefulSet: statefulSetName, + Phase: state.Phase, + PodName: state.PodName, + PodOrdinal: state.PodOrdinal, + OldPodUID: state.OldPodUID, + LastTransitionTime: metav1.NewTime(state.LastUpdateTime), + Message: message, + }, emitEvent) +} + +func clearMiddleManagerDrainState(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, emitEvent EventEmitter) error { + middleManagerDrainStates.Delete(middleManagerDrainStateKey(drd.Namespace, drd.Name, statefulSetName)) + if drd.Status.MiddleManagerDrain == nil || drd.Status.MiddleManagerDrain.StatefulSet != statefulSetName { + return nil + } + return patchMiddleManagerDrainStatus(ctx, sdk, drd, nil, emitEvent) +} + +func patchMiddleManagerDrainStatus(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, status *v1alpha1.MiddleManagerDrainStatus, emitEvent EventEmitter) error { + updatedStatus := drd.Status + updatedStatus.MiddleManagerDrain = status + if err := druidClusterStatusPatcher(ctx, sdk, updatedStatus, drd, emitEvent); err != nil { + return err + } + drd.Status = updatedStatus + return nil +} + +func updateStatefulSetPartition(ctx context.Context, sdk client.Client, statefulSetName, namespace string, partition int32) error { + var sts appsv1.StatefulSet + if err := sdk.Get(ctx, types.NamespacedName{Name: statefulSetName, Namespace: namespace}, &sts); err != nil { + return fmt.Errorf("failed to get StatefulSet %q in namespace %q: %w", statefulSetName, namespace, err) + } + + if sts.Spec.UpdateStrategy.RollingUpdate == nil { + sts.Spec.UpdateStrategy.Type = appsv1.RollingUpdateStatefulSetStrategyType + sts.Spec.UpdateStrategy.RollingUpdate = &appsv1.RollingUpdateStatefulSetStrategy{} + } + if sts.Spec.UpdateStrategy.RollingUpdate.Partition != nil && *sts.Spec.UpdateStrategy.RollingUpdate.Partition == partition { + return nil + } + + sts.Spec.UpdateStrategy.RollingUpdate.Partition = &partition + return sdk.Update(ctx, &sts) +} + +func middleManagerDrainStatefulSetUpdaterFn(prev, curr object) { + currSts, ok := curr.(*appsv1.StatefulSet) + if !ok { + return + } + + replicas := int32(1) + if currSts.Spec.Replicas != nil { + replicas = *currSts.Spec.Replicas + } + if currSts.Spec.UpdateStrategy.RollingUpdate == nil { + currSts.Spec.UpdateStrategy.Type = appsv1.RollingUpdateStatefulSetStrategyType + currSts.Spec.UpdateStrategy.RollingUpdate = &appsv1.RollingUpdateStatefulSetStrategy{} + } + currSts.Spec.UpdateStrategy.RollingUpdate.Partition = &replicas +} + +func cleanupStaleMiddleManagerDrainState(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, emitEvent EventEmitter) { + state, hasState := loadMiddleManagerDrainState(drd, statefulSetName) + if !hasState { + return + } + + if hasState { + api, err := newMiddleManagerDruidAPI(ctx, sdk, drd) + if err != nil { + logger.Error(err, "Failed to create Druid API client while cleaning stale MiddleManager drain state", "statefulSet", statefulSetName) + } else { + druidPort, portErr := getDruidPortFromStatefulSet(ctx, sdk, statefulSetName, drd.Namespace) + if portErr != nil { + logger.Error(portErr, "Failed to get Druid port while cleaning stale MiddleManager drain state", "statefulSet", statefulSetName) + } else { + workerHost := buildMiddleManagerWorkerHost(state.PodName, statefulSetName, drd.Namespace, druidPort) + if err := api.EnableWorker(workerHost); err != nil { + logger.Error(err, "Failed to re-enable MiddleManager while cleaning stale drain state", "pod", state.PodName, "statefulSet", statefulSetName) + } + } + } + } + + if err := clearMiddleManagerDrainState(ctx, sdk, drd, statefulSetName, emitEvent); err != nil { + logger.Error(err, "Failed to clear stale MiddleManager drain status", "statefulSet", statefulSetName) + } + if err := updateStatefulSetPartition(ctx, sdk, statefulSetName, drd.Namespace, statefulSetPartitionResetValue); err != nil { + logger.Error(err, "Failed to reset StatefulSet partition while cleaning stale MiddleManager drain state", "statefulSet", statefulSetName) + } +} + +func processMiddleManagerRollingRestart(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, strategy *v1alpha1.MiddleManagerDrainStrategy, emitEvent EventEmitter) error { + var sts appsv1.StatefulSet + if err := sdk.Get(ctx, types.NamespacedName{Name: statefulSetName, Namespace: drd.Namespace}, &sts); err != nil { + return fmt.Errorf("failed to get StatefulSet %q in namespace %q: %w", statefulSetName, drd.Namespace, err) + } + + if sts.Status.CurrentRevision == sts.Status.UpdateRevision { + if err := clearMiddleManagerDrainState(ctx, sdk, drd, statefulSetName, emitEvent); err != nil { + return err + } + return updateStatefulSetPartition(ctx, sdk, statefulSetName, drd.Namespace, statefulSetPartitionResetValue) + } + + config := normalizeMiddleManagerDrainConfig(strategy) + totalReplicas := int32(1) + if sts.Spec.Replicas != nil { + totalReplicas = *sts.Spec.Replicas + } + + state, hasState := loadMiddleManagerDrainState(drd, statefulSetName) + if !hasState { + if err := updateStatefulSetPartition(ctx, sdk, statefulSetName, drd.Namespace, totalReplicas); err != nil { + return fmt.Errorf("failed to block MiddleManager StatefulSet rolling update: %w", err) + } + } + + api, err := newMiddleManagerDruidAPI(ctx, sdk, drd) + if err != nil { + return err + } + + druidPort, err := getDruidPortFromStatefulSet(ctx, sdk, statefulSetName, drd.Namespace) + if err != nil { + return err + } + + if hasState { + return continueMiddleManagerDrainCycle(ctx, sdk, drd, &sts, api, state, druidPort, config, emitEvent) + } + return startMiddleManagerDrainCycle(ctx, sdk, drd, &sts, api, druidPort, emitEvent) +} + +func startMiddleManagerDrainCycle(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, sts *appsv1.StatefulSet, api middleManagerDruidAPI, druidPort int32, emitEvent EventEmitter) error { + outdatedPods, err := getOutdatedMiddleManagerPods(ctx, sdk, sts.Name, sts.Namespace, sts.Status.CurrentRevision) + if err != nil { + return err + } + if len(outdatedPods) == 0 { + return nil + } + + sortPodsDescending(outdatedPods) + targetPod := outdatedPods[0] + podOrdinal := extractPodOrdinal(targetPod.Name) + if podOrdinal < 0 { + return fmt.Errorf("could not extract ordinal from MiddleManager pod name %q", targetPod.Name) + } + + workerHost := buildMiddleManagerWorkerHost(targetPod.Name, sts.Name, sts.Namespace, druidPort) + if targetPod.Labels["controller-revision-hash"] == sts.Status.UpdateRevision { + if err := api.EnableWorker(workerHost); err != nil { + logger.Error(err, "Failed to enable already-updated MiddleManager pod", "pod", targetPod.Name) + } + return nil + } + + if err := drainMiddleManager(api, workerHost); err != nil { + return err + } + + return setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, &middleManagerDrainState{ + Phase: middleManagerDrainPhaseDraining, + PodName: targetPod.Name, + PodOrdinal: podOrdinal, + }, "Drain initiated; waiting for streaming ingestion tasks to finish", emitEvent) +} + +func continueMiddleManagerDrainCycle(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, sts *appsv1.StatefulSet, api middleManagerDruidAPI, state *middleManagerDrainState, druidPort int32, config middleManagerDrainConfig, emitEvent EventEmitter) error { + workerHost := buildMiddleManagerWorkerHost(state.PodName, sts.Name, sts.Namespace, druidPort) + elapsed := time.Since(state.LastUpdateTime) + + switch state.Phase { + case middleManagerDrainPhaseDraining: + if elapsed < config.DrainTimeout { + drained, err := isMiddleManagerDrained(api, workerHost) + if err != nil { + logger.Error(err, "Failed to check MiddleManager drain status; will retry", "pod", state.PodName) + return nil + } + if !drained { + return setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, state, "Waiting for streaming ingestion tasks to drain", emitEvent) + } + } + + oldPodUID := "" + var oldPod v1.Pod + if err := sdk.Get(ctx, types.NamespacedName{Name: state.PodName, Namespace: sts.Namespace}, &oldPod); err == nil { + oldPodUID = string(oldPod.UID) + } + if err := updateStatefulSetPartition(ctx, sdk, sts.Name, sts.Namespace, state.PodOrdinal); err != nil { + return fmt.Errorf("failed to lower StatefulSet partition for pod %q: %w", state.PodName, err) + } + return setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, &middleManagerDrainState{ + Phase: middleManagerDrainPhaseWaitingForPod, + PodName: state.PodName, + PodOrdinal: state.PodOrdinal, + OldPodUID: oldPodUID, + }, "Drain complete; waiting for replacement pod to become ready", emitEvent) + + case middleManagerDrainPhaseWaitingForPod: + if elapsed >= config.PodReadyTimeout { + blockedState := *state + blockedState.Phase = middleManagerDrainPhaseBlocked + blockedState.LastUpdateTime = time.Time{} + message := fmt.Sprintf("Timed out after %s waiting for replacement pod to become ready", config.PodReadyTimeout) + if err := setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, &blockedState, message, emitEvent); err != nil { + return err + } + return fmt.Errorf("MiddleManager pod %q did not become ready before timeout %s", state.PodName, config.PodReadyTimeout) + } + + var pod v1.Pod + if err := sdk.Get(ctx, types.NamespacedName{Name: state.PodName, Namespace: sts.Namespace}, &pod); err != nil { + return nil + } + if state.OldPodUID != "" && string(pod.UID) == state.OldPodUID { + return nil + } + if !isPodReady(pod) { + return nil + } + if pod.Labels["controller-revision-hash"] != sts.Status.UpdateRevision { + if err := clearMiddleManagerDrainState(ctx, sdk, drd, sts.Name, emitEvent); err != nil { + return err + } + return nil + } + if err := api.EnableWorker(workerHost); err != nil { + logger.Error(err, "Failed to re-enable MiddleManager pod; will retry", "pod", state.PodName) + return nil + } + return clearMiddleManagerDrainState(ctx, sdk, drd, sts.Name, emitEvent) + + case middleManagerDrainPhaseBlocked: + message := "" + if drd.Status.MiddleManagerDrain != nil { + message = drd.Status.MiddleManagerDrain.Message + } + return fmt.Errorf("MiddleManager drain rollout is blocked for pod %q: %s", state.PodName, message) + + default: + return clearMiddleManagerDrainState(ctx, sdk, drd, sts.Name, emitEvent) + } +} + +func drainMiddleManager(api middleManagerDruidAPI, workerHost string) error { + if err := api.DisableWorker(workerHost); err != nil { + return fmt.Errorf("failed to disable MiddleManager: %w", err) + } + + runningTasks, err := getRunningTasksFromSQL(api, workerHost) + if err != nil { + return fmt.Errorf("failed to get running tasks after disabling MiddleManager: %w", err) + } + if len(runningTasks) == 0 { + return nil + } + + handoffs, err := resolveHandoffsForWorker(api, runningTasks) + if err != nil { + return err + } + for supervisorID, taskGroupIDs := range handoffs { + if len(taskGroupIDs) == 0 { + continue + } + if err := api.TriggerTaskGroupHandoff(supervisorID, taskGroupIDs); err != nil { + return fmt.Errorf("failed to trigger handoff for supervisor %q: %w", supervisorID, err) + } + } + return nil +} + +func resolveHandoffsForWorker(api middleManagerDruidAPI, runningTaskIDs []string) (map[string][]int, error) { + supervisorToGroupIDs := map[string]map[int]bool{} + for _, taskID := range runningTaskIDs { + payload, err := api.GetTaskPayload(taskID) + if err != nil { + logger.Error(err, "Failed to fetch task payload while resolving MiddleManager handoff", "taskID", taskID) + continue + } + if payload.Payload.DataSource == "" || payload.Payload.IOConfig.TaskGroupID == nil { + continue + } + if supervisorToGroupIDs[payload.Payload.DataSource] == nil { + supervisorToGroupIDs[payload.Payload.DataSource] = map[int]bool{} + } + supervisorToGroupIDs[payload.Payload.DataSource][*payload.Payload.IOConfig.TaskGroupID] = true + } + + result := map[string][]int{} + for supervisorID, groupIDSet := range supervisorToGroupIDs { + ids := make([]int, 0, len(groupIDSet)) + for id := range groupIDSet { + ids = append(ids, id) + } + sort.Ints(ids) + result[supervisorID] = ids + } + return result, nil +} + +func getRunningTasksFromSQL(api middleManagerDruidAPI, workerHost string) ([]string, error) { + hostname, err := validateWorkerHostnameForSQL(workerHost) + if err != nil { + return nil, err + } + + query := fmt.Sprintf(`SELECT "task_id", "datasource", "type" FROM sys.tasks WHERE "runner_status" = 'RUNNING' AND "location" LIKE '%s:%%'`, hostname) + body, err := api.ExecuteSQL(query) + if err != nil { + return nil, err + } + + var rows []runningTaskInfo + if err := json.Unmarshal(body, &rows); err != nil { + return nil, fmt.Errorf("failed to decode running tasks SQL response: %w", err) + } + + taskIDs := make([]string, 0, len(rows)) + for _, row := range rows { + if row.TaskID != "" { + taskIDs = append(taskIDs, row.TaskID) + } + } + return taskIDs, nil +} + +func isMiddleManagerDrained(api middleManagerDruidAPI, workerHost string) (bool, error) { + hostname, err := validateWorkerHostnameForSQL(workerHost) + if err != nil { + return false, err + } + + query := fmt.Sprintf(`SELECT COUNT(*) AS "cnt" FROM sys.tasks WHERE "runner_status" = 'RUNNING' AND "location" LIKE '%s:%%' AND "type" IN ('index_kafka', 'index_kinesis')`, hostname) + body, err := api.ExecuteSQL(query) + if err != nil { + return false, err + } + + var rows []struct { + Cnt int `json:"cnt"` + } + if err := json.Unmarshal(body, &rows); err != nil { + return false, fmt.Errorf("failed to decode drain check SQL response: %w", err) + } + return len(rows) == 0 || rows[0].Cnt == 0, nil +} + +func validateWorkerHostnameForSQL(workerHost string) (string, error) { + hostname := stripPort(workerHost) + if !workerHostnamePattern.MatchString(hostname) { + return "", fmt.Errorf("invalid MiddleManager worker hostname %q", hostname) + } + return hostname, nil +} + +func stripPort(hostPort string) string { + idx := strings.LastIndex(hostPort, ":") + if idx < 0 { + return hostPort + } + return hostPort[:idx] +} + +func buildMiddleManagerWorkerHost(podName, serviceName, namespace string, port int32) string { + return fmt.Sprintf("%s.%s.%s.svc.cluster.local:%d", podName, serviceName, namespace, port) +} + +func getDruidPortFromStatefulSet(ctx context.Context, sdk client.Client, statefulSetName, namespace string) (int32, error) { + var sts appsv1.StatefulSet + if err := sdk.Get(ctx, types.NamespacedName{Name: statefulSetName, Namespace: namespace}, &sts); err != nil { + return 0, err + } + for _, container := range sts.Spec.Template.Spec.Containers { + for _, port := range container.Ports { + if port.Name == "druid-port" { + return port.ContainerPort, nil + } + } + } + return 0, fmt.Errorf("druid-port not found in StatefulSet %q pod template", statefulSetName) +} + +func getOutdatedMiddleManagerPods(ctx context.Context, sdk client.Client, statefulSetName, namespace, currentRevision string) ([]v1.Pod, error) { + var podList v1.PodList + if err := sdk.List(ctx, &podList, client.InNamespace(namespace), client.MatchingLabels{ + "nodeSpecUniqueStr": statefulSetName, + }); err != nil { + return nil, fmt.Errorf("failed to list MiddleManager pods: %w", err) + } + + outdatedPods := make([]v1.Pod, 0, len(podList.Items)) + for _, pod := range podList.Items { + if pod.Labels["controller-revision-hash"] == currentRevision { + outdatedPods = append(outdatedPods, pod) + } + } + return outdatedPods, nil +} + +func sortPodsDescending(pods []v1.Pod) { + sort.SliceStable(pods, func(i, j int) bool { + return extractPodOrdinal(pods[i].Name) > extractPodOrdinal(pods[j].Name) + }) +} + +func extractPodOrdinal(name string) int32 { + re := regexp.MustCompile(`\d+$`) + match := re.FindString(name) + if match == "" { + return -1 + } + num, err := strconv.Atoi(match) + if err != nil { + return -1 + } + return int32(num) +} + +func isPodReady(pod v1.Pod) bool { + if pod.Status.Phase != v1.PodRunning { + return false + } + for _, condition := range pod.Status.Conditions { + if condition.Type == v1.PodReady && condition.Status == v1.ConditionTrue { + return true + } + } + return false +} diff --git a/controllers/druid/middle_manager_drain_test.go b/controllers/druid/middle_manager_drain_test.go new file mode 100644 index 0000000..acd69cd --- /dev/null +++ b/controllers/druid/middle_manager_drain_test.go @@ -0,0 +1,236 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ +package druid + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/apache/druid-operator/apis/druid/v1alpha1" + internalhttp "github.com/apache/druid-operator/pkg/http" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + appsv1 "k8s.io/api/apps/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +type fakeMiddleManagerDruidAPI struct { + sqlResponses map[string][]byte + payloads map[string]*taskPayloadResponse + disabled []string + enabled []string + handoffs map[string][]int +} + +func (f *fakeMiddleManagerDruidAPI) DisableWorker(workerHost string) error { + f.disabled = append(f.disabled, workerHost) + return nil +} + +func (f *fakeMiddleManagerDruidAPI) EnableWorker(workerHost string) error { + f.enabled = append(f.enabled, workerHost) + return nil +} + +func (f *fakeMiddleManagerDruidAPI) GetTaskPayload(taskID string) (*taskPayloadResponse, error) { + return f.payloads[taskID], nil +} + +func (f *fakeMiddleManagerDruidAPI) TriggerTaskGroupHandoff(supervisorID string, taskGroupIDs []int) error { + if f.handoffs == nil { + f.handoffs = map[string][]int{} + } + f.handoffs[supervisorID] = taskGroupIDs + return nil +} + +func (f *fakeMiddleManagerDruidAPI) ExecuteSQL(query string) ([]byte, error) { + return f.sqlResponses["default"], nil +} + +func taskPayload(supervisorID string, taskGroupID int) *taskPayloadResponse { + payload := &taskPayloadResponse{} + payload.Payload.DataSource = supervisorID + payload.Payload.IOConfig.TaskGroupID = &taskGroupID + return payload +} + +func TestMiddleManagerDruidHTTPAPIEscapesPathSegments(t *testing.T) { + requestURIs := []string{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestURIs = append(requestURIs, r.RequestURI) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"task":"task/a","payload":{"dataSource":"wiki","ioConfig":{"taskGroupId":1}}}`)) + })) + defer server.Close() + + api := &middleManagerDruidHTTPAPI{ + baseURL: server.URL, + httpClient: internalhttp.NewHTTPClient(server.Client(), &internalhttp.Auth{}), + } + + require.NoError(t, api.DisableWorker("mm-0.druid-mm.ns.svc.cluster.local:8091")) + _, err := api.GetTaskPayload("task/with/slashes") + require.NoError(t, err) + require.NoError(t, api.TriggerTaskGroupHandoff("supervisor/with/slash", []int{1})) + + require.Len(t, requestURIs, 3) + assert.Contains(t, requestURIs[0], "mm-0.druid-mm.ns.svc.cluster.local:8091") + assert.Contains(t, requestURIs[1], "task%2Fwith%2Fslashes") + assert.Contains(t, requestURIs[2], "supervisor%2Fwith%2Fslash") +} + +func TestDrainMiddleManagerTriggersDeduplicatedHandoffs(t *testing.T) { + api := &fakeMiddleManagerDruidAPI{ + sqlResponses: map[string][]byte{ + "default": []byte(`[ + {"task_id":"task-0","datasource":"wiki","type":"index_kafka"}, + {"task_id":"task-1","datasource":"wiki","type":"index_kafka"}, + {"task_id":"task-duplicate","datasource":"wiki","type":"index_kafka"} + ]`), + }, + payloads: map[string]*taskPayloadResponse{ + "task-0": taskPayload("wiki", 1), + "task-1": taskPayload("wiki", 2), + "task-duplicate": taskPayload("wiki", 1), + }, + } + + workerHost := "mm-0.druid-mm.druid.svc.cluster.local:8091" + require.NoError(t, drainMiddleManager(api, workerHost)) + + assert.Equal(t, []string{workerHost}, api.disabled) + assert.Equal(t, []int{1, 2}, api.handoffs["wiki"]) +} + +func TestValidateWorkerHostnameForSQLRejectsUnsafeHost(t *testing.T) { + _, err := validateWorkerHostnameForSQL("mm-0.druid-mm.druid.svc.cluster.local:8091") + require.NoError(t, err) + + _, err = validateWorkerHostnameForSQL("mm-0.bad'host.druid.svc.cluster.local:8091") + require.Error(t, err) +} + +func TestNormalizeMiddleManagerDrainConfig(t *testing.T) { + assert.Equal(t, middleManagerDrainConfig{ + DrainTimeout: defaultMiddleManagerDrainTimeout, + PodReadyTimeout: defaultMiddleManagerPodReadyTimeout, + }, normalizeMiddleManagerDrainConfig(nil)) + + config := normalizeMiddleManagerDrainConfig(&v1alpha1.MiddleManagerDrainStrategy{ + DrainTimeout: metav1.Duration{Duration: 2 * time.Hour}, + PodReadyTimeout: metav1.Duration{Duration: 10 * time.Minute}, + }) + assert.Equal(t, 2*time.Hour, config.DrainTimeout) + assert.Equal(t, 10*time.Minute, config.PodReadyTimeout) +} + +func TestMiddleManagerDrainStatefulSetUpdaterFnBlocksRollout(t *testing.T) { + replicas := int32(3) + sts := &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{Name: "druid-mm"}, + Spec: appsv1.StatefulSetSpec{ + Replicas: &replicas, + UpdateStrategy: appsv1.StatefulSetUpdateStrategy{ + Type: appsv1.RollingUpdateStatefulSetStrategyType, + }, + }, + } + + middleManagerDrainStatefulSetUpdaterFn(nil, sts) + + require.NotNil(t, sts.Spec.UpdateStrategy.RollingUpdate) + require.NotNil(t, sts.Spec.UpdateStrategy.RollingUpdate.Partition) + assert.Equal(t, replicas, *sts.Spec.UpdateStrategy.RollingUpdate.Partition) +} + +func TestContinueMiddleManagerDrainCycleBlocksOnPodReadyTimeout(t *testing.T) { + require.NoError(t, v1alpha1.AddToScheme(scheme.Scheme)) + + drd := &v1alpha1.Druid{ + ObjectMeta: metav1.ObjectMeta{ + Name: "druid", + Namespace: "druid", + }, + Status: v1alpha1.DruidClusterStatus{ + MiddleManagerDrain: &v1alpha1.MiddleManagerDrainStatus{ + StatefulSet: "druid-mm", + Phase: middleManagerDrainPhaseWaitingForPod, + PodName: "druid-mm-0", + }, + }, + } + sts := &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "druid-mm", + Namespace: "druid", + }, + Status: appsv1.StatefulSetStatus{ + CurrentRevision: "old", + UpdateRevision: "new", + }, + } + + k8sClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects(drd, sts). + WithStatusSubresource(drd). + Build() + + state := &middleManagerDrainState{ + Phase: middleManagerDrainPhaseWaitingForPod, + PodName: "druid-mm-0", + PodOrdinal: 0, + LastUpdateTime: time.Now().Add(-31 * time.Minute), + } + + err := continueMiddleManagerDrainCycle( + context.Background(), + k8sClient, + drd, + sts, + &fakeMiddleManagerDruidAPI{}, + state, + 8091, + middleManagerDrainConfig{DrainTimeout: time.Hour, PodReadyTimeout: 30 * time.Minute}, + EmitEventFuncs{record.NewFakeRecorder(10)}, + ) + + require.Error(t, err) + blockedState, exists := getMiddleManagerDrainState("druid", "druid", "druid-mm") + require.True(t, exists) + assert.Equal(t, middleManagerDrainPhaseBlocked, blockedState.Phase) + + var updated v1alpha1.Druid + require.NoError(t, k8sClient.Get(context.Background(), clientObjectKey("druid", "druid"), &updated)) + require.NotNil(t, updated.Status.MiddleManagerDrain) + assert.Equal(t, middleManagerDrainPhaseBlocked, updated.Status.MiddleManagerDrain.Phase) + assert.Contains(t, updated.Status.MiddleManagerDrain.Message, "Timed out") +} + +func clientObjectKey(namespace, name string) client.ObjectKey { + return client.ObjectKey{Namespace: namespace, Name: name} +} diff --git a/controllers/druid/status.go b/controllers/druid/status.go index da940e8..357a1d2 100644 --- a/controllers/druid/status.go +++ b/controllers/druid/status.go @@ -67,7 +67,9 @@ func druidClusterStatusPatcher(ctx context.Context, sdk client.Client, updatedSt if err != nil { return fmt.Errorf("failed to serialize status patch to bytes: %v", err) } - _ = writers.Patch(ctx, sdk, m, m, true, client.RawPatch(types.MergePatchType, patchBytes), emitEvent) + if err := writers.Patch(ctx, sdk, m, m, true, client.RawPatch(types.MergePatchType, patchBytes), emitEvent); err != nil { + return err + } } return nil } diff --git a/controllers/druid/suite_test.go b/controllers/druid/suite_test.go index 81d311b..0188334 100644 --- a/controllers/druid/suite_test.go +++ b/controllers/druid/suite_test.go @@ -99,7 +99,8 @@ var _ = BeforeSuite(func() { Expect(k8sClient).NotTo(BeNil()) k8sManager, err := ctrl.NewManager(cfg, ctrl.Options{ - Scheme: scheme.Scheme, + Scheme: scheme.Scheme, + MetricsBindAddress: "0", }) Expect(err).ToNot(HaveOccurred()) diff --git a/docs/api_specifications/druid.md b/docs/api_specifications/druid.md index 687e50f..f0ad39c 100644 --- a/docs/api_specifications/druid.md +++ b/docs/api_specifications/druid.md @@ -1,21 +1,3 @@ -
Packages:
Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +“License”); you may not use this file except in compliance +with the License. You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +“AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License.
Resource Types:middleManagerDrainStrategyMiddleManagerDrainStrategy enables operator-managed draining before +MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the +standard StatefulSet rolling update behavior.
+defaultProbesmiddleManagerDrainmiddleManagerDrainStrategyMiddleManagerDrainStrategy enables operator-managed draining before +MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the +standard StatefulSet rolling update behavior.
+defaultProbes+(Appears on: +DruidClusterStatus) +
+MiddleManagerDrainStatus reports an in-progress MiddleManager drain rollout.
+| Field | +Description | +
|---|---|
+statefulSet+ +string + + |
++ | +
+phase+ +string + + |
++ | +
+podName+ +string + + |
++ | +
+podOrdinal+ +int32 + + |
++ | +
+oldPodUID+ +string + + |
++ | +
+lastTransitionTime+ + +Kubernetes meta/v1.Time + + + |
++ | +
+message+ +string + + |
++ | +
+(Appears on: +DruidSpec) +
+MiddleManagerDrainStrategy configures operator-managed draining before a +MiddleManager StatefulSet pod is rolled to a new revision.
+| Field | +Description | +
|---|---|
+drainTimeout+ + +Kubernetes meta/v1.Duration + + + |
+
+(Optional)
+ DrainTimeout is the maximum time to wait for streaming ingestion tasks to +drain before allowing Kubernetes to replace the MiddleManager pod. + |
+
+podReadyTimeout+ + +Kubernetes meta/v1.Duration + + + |
+
+(Optional)
+ PodReadyTimeout is the maximum time to wait for Kubernetes to replace the +pod and for the replacement to become ready on the target StatefulSet revision. + |
+
diff --git a/docs/druid_cr.md b/docs/druid_cr.md index a082062..3184e13 100644 --- a/docs/druid_cr.md +++ b/docs/druid_cr.md @@ -31,6 +31,11 @@ spec: # more information in features.md and in druid documentation # http://druid.io/docs/latest/operations/rolling-updates.html rollingDeploy: true + # Optional: drain MiddleManager StatefulSet pods before replacing them during rolling updates. + # Omit this field to use standard StatefulSet rolling update behavior. + middleManagerDrainStrategy: + drainTimeout: 1h + podReadyTimeout: 30m # Image for druid, Required Key image: apache/druid:25.0.0 .... diff --git a/docs/features.md b/docs/features.md index 6f8b74b..2eab033 100644 --- a/docs/features.md +++ b/docs/features.md @@ -23,6 +23,7 @@ under the License. - [Finalizer in Druid CR](#finalizer-in-druid-cr) - [Deletion of Orphan PVCs](#deletion-of-orphan-pvcs) - [Rolling Deploy](#rolling-deploy) +- [MiddleManager Drain Strategy](#middlemanager-drain-strategy) - [Force Delete of Sts Pods](#force-delete-of-sts-pods) - [Horizontal Scaling of Druid Pods](#horizontal-scaling-of-druid-pods) - [Volume Expansion of Druid Pods Running As StatefulSets](#volume-expansion-of-druid-pods-running-as-statefulsets) @@ -68,6 +69,39 @@ Default updates are done in parallel. Since cluster creation does not require a in parallel anyway. To enable this feature, set `rollingDeploy: true` in the Druid CR. ⚠️ This feature is enabled by default. +## MiddleManager Drain Strategy +For streaming ingestion clusters, a regular StatefulSet rolling update can terminate MiddleManager pods while they still +own running Kafka or Kinesis indexing tasks. The operator can instead drain each MiddleManager pod before Kubernetes +replaces it. + +Enable this opt-in behavior by setting `spec.middleManagerDrainStrategy`. The field uses pointer presence: if the field +is omitted, MiddleManagers use the standard StatefulSet rolling update behavior. + +```yaml +spec: + rollingDeploy: true + middleManagerDrainStrategy: + drainTimeout: 1h + podReadyTimeout: 30m +``` + +When enabled, the operator: + +1. Blocks the MiddleManager StatefulSet rollout by setting the rolling update partition to the replica count. +2. Selects one outdated MiddleManager pod at a time, highest ordinal first. +3. Disables the worker through the Druid Overlord API. +4. Finds running streaming tasks with Druid SQL and triggers supervisor task-group handoff. +5. Waits for running Kafka/Kinesis tasks on that pod to drain, up to `drainTimeout`. +6. Lowers the StatefulSet partition for that pod and waits for the replacement pod to be ready on the new revision. +7. Re-enables the worker and proceeds to the next pod on the following reconcile. + +The strategy is supported only for MiddleManager nodes running as StatefulSets. If configured for a MiddleManager +Deployment, the operator logs a warning and uses the normal Deployment rollout path. + +This feature requires the operator to reach the Druid Router service and use the Druid API credentials configured in +`spec.auth`, when authentication is enabled. If the Router service cannot be discovered, the rollout is blocked and the +error is surfaced by the reconcile loop. + ## Force Delete of Sts Pods During upgradeS, if THE StatefulSet is set to `OrderedReady` - the StatefulSet controller will not recover from crash-loopback state. The issues is referenced [here](https://github.com/kubernetes/kubernetes/issues/67250). diff --git a/pkg/druidapi/druidapi.go b/pkg/druidapi/druidapi.go index 3cf4749..028f87c 100644 --- a/pkg/druidapi/druidapi.go +++ b/pkg/druidapi/druidapi.go @@ -141,6 +141,18 @@ func MakePath(baseURL, componentType, apiType string, additionalPaths ...string) return u.String() } +// MakeSQLPath constructs the path for Druid's SQL API. +func MakeSQLPath(baseURL string) string { + u, err := url.Parse(baseURL) + if err != nil { + fmt.Println("Error parsing URL:", err) + return "" + } + + u.Path = path.Join("druid", "v2", "sql") + return u.String() +} + // GetRouterSvcUrl retrieves the URL of the Druid router service. // Parameters: // diff --git a/pkg/druidapi/druidapi_test.go b/pkg/druidapi/druidapi_test.go index ebcb6db..b7d80e6 100644 --- a/pkg/druidapi/druidapi_test.go +++ b/pkg/druidapi/druidapi_test.go @@ -182,3 +182,36 @@ func TestMakePath(t *testing.T) { }) } } + +func TestMakeSQLPath(t *testing.T) { + tests := []struct { + name string + baseURL string + expected string + }{ + { + name: "RouterService", + baseURL: "http://example-druid-service", + expected: "http://example-druid-service/druid/v2/sql", + }, + { + name: "BaseURLWithPath", + baseURL: "http://example-druid-service/base", + expected: "http://example-druid-service/druid/v2/sql", + }, + { + name: "EmptyBaseURL", + baseURL: "", + expected: "druid/v2/sql", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := MakeSQLPath(tt.baseURL) + if actual != tt.expected { + t.Errorf("MakeSQLPath() = %v, expected %v", actual, tt.expected) + } + }) + } +}