From 7ffc3771a6a04f9d22956e01ab00b1ba5fa60d82 Mon Sep 17 00:00:00 2001
From: Aru Raghuwanshi
Date: Sun, 10 May 2026 02:35:23 +0530
Subject: [PATCH] Add MiddleManager drain rollout strategy
Introduce an opt-in CRD strategy that drains MiddleManager StatefulSet pods before rolling them, with status visibility, timeout handling, and upstream Druid API auth plumbing.
---
apis/druid/v1alpha1/druid_types.go | 54 +-
apis/druid/v1alpha1/zz_generated.deepcopy.go | 43 ++
chart/crds/druid.apache.org_druids.yaml | 40 +
config/crd/bases/druid.apache.org_druids.yaml | 40 +
controllers/druid/handler.go | 28 +-
controllers/druid/middle_manager_drain.go | 697 ++++++++++++++++++
.../druid/middle_manager_drain_test.go | 236 ++++++
controllers/druid/status.go | 4 +-
controllers/druid/suite_test.go | 3 +-
docs/api_specifications/druid.md | 221 +++++-
docs/druid_cr.md | 5 +
docs/features.md | 34 +
pkg/druidapi/druidapi.go | 12 +
pkg/druidapi/druidapi_test.go | 33 +
14 files changed, 1419 insertions(+), 31 deletions(-)
create mode 100644 controllers/druid/middle_manager_drain.go
create mode 100644 controllers/druid/middle_manager_drain_test.go
diff --git a/apis/druid/v1alpha1/druid_types.go b/apis/druid/v1alpha1/druid_types.go
index 34c68a1d..34132802 100644
--- a/apis/druid/v1alpha1/druid_types.go
+++ b/apis/druid/v1alpha1/druid_types.go
@@ -93,6 +93,22 @@ type AdditionalContainer struct {
EnvFrom []v1.EnvFromSource `json:"envFrom,omitempty"`
}
+// MiddleManagerDrainStrategy configures operator-managed draining before a
+// MiddleManager StatefulSet pod is rolled to a new revision.
+type MiddleManagerDrainStrategy struct {
+ // DrainTimeout is the maximum time to wait for streaming ingestion tasks to
+ // drain before allowing Kubernetes to replace the MiddleManager pod.
+ // +optional
+ // +kubebuilder:default:="1h"
+ DrainTimeout metav1.Duration `json:"drainTimeout,omitempty"`
+
+ // PodReadyTimeout is the maximum time to wait for Kubernetes to replace the
+ // pod and for the replacement to become ready on the target StatefulSet revision.
+ // +optional
+ // +kubebuilder:default:="30m"
+ PodReadyTimeout metav1.Duration `json:"podReadyTimeout,omitempty"`
+}
+
// DruidSpec defines the desired state of the Druid cluster.
type DruidSpec struct {
@@ -270,6 +286,12 @@ type DruidSpec struct {
// +kubebuilder:default:=true
RollingDeploy bool `json:"rollingDeploy"`
+ // MiddleManagerDrainStrategy enables operator-managed draining before
+ // MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the
+ // standard StatefulSet rolling update behavior.
+ // +optional
+ MiddleManagerDrainStrategy *MiddleManagerDrainStrategy `json:"middleManagerDrainStrategy,omitempty"`
+
// DefaultProbes If set to true this will add default probes (liveness / readiness / startup) for all druid components
// but it won't override existing probes
// +optional
@@ -570,20 +592,32 @@ type DruidNodeTypeStatus struct {
Reason string `json:"reason,omitempty"`
}
+// MiddleManagerDrainStatus reports an in-progress MiddleManager drain rollout.
+type MiddleManagerDrainStatus struct {
+ StatefulSet string `json:"statefulSet,omitempty"`
+ Phase string `json:"phase,omitempty"`
+ PodName string `json:"podName,omitempty"`
+ PodOrdinal int32 `json:"podOrdinal,omitempty"`
+ OldPodUID string `json:"oldPodUID,omitempty"`
+ LastTransitionTime metav1.Time `json:"lastTransitionTime,omitempty"`
+ Message string `json:"message,omitempty"`
+}
+
// DruidClusterStatus Defines the observed state of Druid.
type DruidClusterStatus struct {
// INSERT ADDITIONAL STATUS FIELD - define observed state of cluster
// Important: Run "make" to regenerate code after modifying this file
- DruidNodeStatus DruidNodeTypeStatus `json:"druidNodeStatus,omitempty"`
- StatefulSets []string `json:"statefulSets,omitempty"`
- Deployments []string `json:"deployments,omitempty"`
- Services []string `json:"services,omitempty"`
- ConfigMaps []string `json:"configMaps,omitempty"`
- PodDisruptionBudgets []string `json:"podDisruptionBudgets,omitempty"`
- Ingress []string `json:"ingress,omitempty"`
- HPAutoScalers []string `json:"hpAutoscalers,omitempty"`
- Pods []string `json:"pods,omitempty"`
- PersistentVolumeClaims []string `json:"persistentVolumeClaims,omitempty"`
+ DruidNodeStatus DruidNodeTypeStatus `json:"druidNodeStatus,omitempty"`
+ StatefulSets []string `json:"statefulSets,omitempty"`
+ Deployments []string `json:"deployments,omitempty"`
+ Services []string `json:"services,omitempty"`
+ ConfigMaps []string `json:"configMaps,omitempty"`
+ PodDisruptionBudgets []string `json:"podDisruptionBudgets,omitempty"`
+ Ingress []string `json:"ingress,omitempty"`
+ HPAutoScalers []string `json:"hpAutoscalers,omitempty"`
+ Pods []string `json:"pods,omitempty"`
+ PersistentVolumeClaims []string `json:"persistentVolumeClaims,omitempty"`
+ MiddleManagerDrain *MiddleManagerDrainStatus `json:"middleManagerDrain,omitempty"`
}
// Druid is the Schema for the druids API.
diff --git a/apis/druid/v1alpha1/zz_generated.deepcopy.go b/apis/druid/v1alpha1/zz_generated.deepcopy.go
index 93eff590..bbfd644b 100644
--- a/apis/druid/v1alpha1/zz_generated.deepcopy.go
+++ b/apis/druid/v1alpha1/zz_generated.deepcopy.go
@@ -181,6 +181,11 @@ func (in *DruidClusterStatus) DeepCopyInto(out *DruidClusterStatus) {
*out = make([]string, len(*in))
copy(*out, *in)
}
+ if in.MiddleManagerDrain != nil {
+ in, out := &in.MiddleManagerDrain, &out.MiddleManagerDrain
+ *out = new(MiddleManagerDrainStatus)
+ (*in).DeepCopyInto(*out)
+ }
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DruidClusterStatus.
@@ -691,6 +696,11 @@ func (in *DruidSpec) DeepCopyInto(out *DruidSpec) {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
+ if in.MiddleManagerDrainStrategy != nil {
+ in, out := &in.MiddleManagerDrainStrategy, &out.MiddleManagerDrainStrategy
+ *out = new(MiddleManagerDrainStrategy)
+ **out = **in
+ }
if in.Zookeeper != nil {
in, out := &in.Zookeeper, &out.Zookeeper
*out = new(ZookeeperSpec)
@@ -769,6 +779,39 @@ func (in *MetadataStoreSpec) DeepCopy() *MetadataStoreSpec {
return out
}
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *MiddleManagerDrainStatus) DeepCopyInto(out *MiddleManagerDrainStatus) {
+ *out = *in
+ in.LastTransitionTime.DeepCopyInto(&out.LastTransitionTime)
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MiddleManagerDrainStatus.
+func (in *MiddleManagerDrainStatus) DeepCopy() *MiddleManagerDrainStatus {
+ if in == nil {
+ return nil
+ }
+ out := new(MiddleManagerDrainStatus)
+ in.DeepCopyInto(out)
+ return out
+}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *MiddleManagerDrainStrategy) DeepCopyInto(out *MiddleManagerDrainStrategy) {
+ *out = *in
+ out.DrainTimeout = in.DrainTimeout
+ out.PodReadyTimeout = in.PodReadyTimeout
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MiddleManagerDrainStrategy.
+func (in *MiddleManagerDrainStrategy) DeepCopy() *MiddleManagerDrainStrategy {
+ if in == nil {
+ return nil
+ }
+ out := new(MiddleManagerDrainStrategy)
+ in.DeepCopyInto(out)
+ return out
+}
+
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ZookeeperSpec) DeepCopyInto(out *ZookeeperSpec) {
*out = *in
diff --git a/chart/crds/druid.apache.org_druids.yaml b/chart/crds/druid.apache.org_druids.yaml
index e2971b36..be95dccc 100644
--- a/chart/crds/druid.apache.org_druids.yaml
+++ b/chart/crds/druid.apache.org_druids.yaml
@@ -1980,6 +1980,25 @@ spec:
stastd documentation is described in the following documentation:
https://druid.apache.org/docs/latest/development/extensions-contrib/statsd.html
type: string
+ middleManagerDrainStrategy:
+ description: |-
+ MiddleManagerDrainStrategy enables operator-managed draining before
+ MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the
+ standard StatefulSet rolling update behavior.
+ properties:
+ drainTimeout:
+ default: 1h
+ description: |-
+ DrainTimeout is the maximum time to wait for streaming ingestion tasks to
+ drain before allowing Kubernetes to replace the MiddleManager pod.
+ type: string
+ podReadyTimeout:
+ default: 30m
+ description: |-
+ PodReadyTimeout is the maximum time to wait for Kubernetes to replace the
+ pod and for the replacement to become ready on the target StatefulSet revision.
+ type: string
+ type: object
nodeSelector:
additionalProperties:
type: string
@@ -11877,6 +11896,27 @@ spec:
items:
type: string
type: array
+ middleManagerDrain:
+ description: MiddleManagerDrainStatus reports an in-progress MiddleManager
+ drain rollout.
+ properties:
+ lastTransitionTime:
+ format: date-time
+ type: string
+ message:
+ type: string
+ oldPodUID:
+ type: string
+ phase:
+ type: string
+ podName:
+ type: string
+ podOrdinal:
+ format: int32
+ type: integer
+ statefulSet:
+ type: string
+ type: object
persistentVolumeClaims:
items:
type: string
diff --git a/config/crd/bases/druid.apache.org_druids.yaml b/config/crd/bases/druid.apache.org_druids.yaml
index e2971b36..be95dccc 100644
--- a/config/crd/bases/druid.apache.org_druids.yaml
+++ b/config/crd/bases/druid.apache.org_druids.yaml
@@ -1980,6 +1980,25 @@ spec:
stastd documentation is described in the following documentation:
https://druid.apache.org/docs/latest/development/extensions-contrib/statsd.html
type: string
+ middleManagerDrainStrategy:
+ description: |-
+ MiddleManagerDrainStrategy enables operator-managed draining before
+ MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the
+ standard StatefulSet rolling update behavior.
+ properties:
+ drainTimeout:
+ default: 1h
+ description: |-
+ DrainTimeout is the maximum time to wait for streaming ingestion tasks to
+ drain before allowing Kubernetes to replace the MiddleManager pod.
+ type: string
+ podReadyTimeout:
+ default: 30m
+ description: |-
+ PodReadyTimeout is the maximum time to wait for Kubernetes to replace the
+ pod and for the replacement to become ready on the target StatefulSet revision.
+ type: string
+ type: object
nodeSelector:
additionalProperties:
type: string
@@ -11877,6 +11896,27 @@ spec:
items:
type: string
type: array
+ middleManagerDrain:
+ description: MiddleManagerDrainStatus reports an in-progress MiddleManager
+ drain rollout.
+ properties:
+ lastTransitionTime:
+ format: date-time
+ type: string
+ message:
+ type: string
+ oldPodUID:
+ type: string
+ phase:
+ type: string
+ podName:
+ type: string
+ podOrdinal:
+ format: int32
+ type: integer
+ statefulSet:
+ type: string
+ type: object
persistentVolumeClaims:
items:
type: string
diff --git a/controllers/druid/handler.go b/controllers/druid/handler.go
index b236ef44..b4499ccf 100644
--- a/controllers/druid/handler.go
+++ b/controllers/druid/handler.go
@@ -147,6 +147,12 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui
nodeSpec.Ports = append(nodeSpec.Ports, v1.ContainerPort{ContainerPort: nodeSpec.DruidPort, Name: "druid-port"})
+ if m.Spec.MiddleManagerDrainStrategy != nil && nodeSpec.NodeType == middleManager && nodeSpec.Kind == "Deployment" {
+ logger.Info("MiddleManager drain strategy is only supported for StatefulSet workloads; using standard Deployment rollout",
+ "nodeSpecUniqueStr", nodeSpecUniqueStr,
+ "namespace", m.Namespace)
+ }
+
if nodeSpec.Kind == "Deployment" {
if deployCreateUpdateStatus, err := sdkCreateOrUpdateAsNeeded(ctx, sdk,
func() (object, error) {
@@ -183,13 +189,22 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui
}
}
+ if m.Generation > 1 && nodeSpec.NodeType == middleManager && m.Spec.MiddleManagerDrainStrategy == nil {
+ cleanupStaleMiddleManagerDrainState(ctx, sdk, m, nodeSpecUniqueStr, emitEvents)
+ }
+
+ stsUpdaterFn := noopUpdaterFn
+ if m.Generation > 1 && nodeSpec.NodeType == middleManager && m.Spec.MiddleManagerDrainStrategy != nil {
+ stsUpdaterFn = middleManagerDrainStatefulSetUpdaterFn
+ }
+
// Create/Update StatefulSet
if stsCreateUpdateStatus, err := sdkCreateOrUpdateAsNeeded(ctx, sdk,
func() (object, error) {
return makeStatefulSet(&nodeSpec, m, lm, nodeSpecUniqueStr, fmt.Sprintf("%s-%s", commonConfigSHA, nodeConfigSHA), firstServiceName)
},
func() object { return &appsv1.StatefulSet{} },
- statefulSetIsEquals, noopUpdaterFn, m, statefulSetNames, emitEvents); err != nil {
+ statefulSetIsEquals, stsUpdaterFn, m, statefulSetNames, emitEvents); err != nil {
return err
} else if m.Spec.RollingDeploy {
@@ -208,6 +223,16 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui
//Check StatefulSet rolling update status, if in-progress then stop here
done, err := isObjFullyDeployed(ctx, sdk, nodeSpec, nodeSpecUniqueStr, m, func() object { return &appsv1.StatefulSet{} }, emitEvents)
if !done {
+ if nodeSpec.NodeType == middleManager && m.Spec.MiddleManagerDrainStrategy != nil {
+ if err := processMiddleManagerRollingRestart(ctx, sdk, m, nodeSpecUniqueStr, m.Spec.MiddleManagerDrainStrategy, emitEvents); err != nil {
+ return err
+ }
+ // The drain state machine owns the StatefulSet partition while a
+ // MiddleManager rollout is in progress. Return immediately so
+ // sdkCreateOrUpdateAsNeeded is not re-entered in this reconcile
+ // and cannot overwrite the partition selected for the current phase.
+ return nil
+ }
return err
}
}
@@ -273,6 +298,7 @@ func deployDruidCluster(ctx context.Context, sdk client.Client, m *v1alpha1.Drui
//update status and delete unwanted resources
updatedStatus := v1alpha1.DruidClusterStatus{}
+ updatedStatus.MiddleManagerDrain = m.Status.MiddleManagerDrain
updatedStatus.StatefulSets = deleteUnusedResources(ctx, sdk, m, statefulSetNames, ls,
func() objectList { return &appsv1.StatefulSetList{} },
diff --git a/controllers/druid/middle_manager_drain.go b/controllers/druid/middle_manager_drain.go
new file mode 100644
index 00000000..712cba59
--- /dev/null
+++ b/controllers/druid/middle_manager_drain.go
@@ -0,0 +1,697 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+package druid
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/apache/druid-operator/apis/druid/v1alpha1"
+ druidapi "github.com/apache/druid-operator/pkg/druidapi"
+ internalhttp "github.com/apache/druid-operator/pkg/http"
+ appsv1 "k8s.io/api/apps/v1"
+ v1 "k8s.io/api/core/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/types"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+// Phases of the per-pod drain state machine. The phase is written to the
+// Druid CR status and mirrored in the in-memory drain state.
+const (
+	middleManagerDrainPhaseDraining      = "Draining"
+	middleManagerDrainPhaseWaitingForPod = "WaitingForPod"
+	middleManagerDrainPhaseBlocked       = "Blocked"
+)
+
+const (
+	// Fallback timeouts used when the strategy leaves a timeout unset or
+	// non-positive.
+	defaultMiddleManagerDrainTimeout    = time.Hour
+	defaultMiddleManagerPodReadyTimeout = 30 * time.Minute
+
+	// statefulSetPartitionResetValue is the partition written back to the
+	// StatefulSet when no drain rollout is in progress.
+	statefulSetPartitionResetValue = int32(0)
+)
+
+var (
+	// middleManagerDrainStates is the in-memory cache of *middleManagerDrainState
+	// values, keyed by middleManagerDrainStateKey ("namespace/druidCR/statefulSet").
+	middleManagerDrainStates = sync.Map{}
+	// workerHostnamePattern matches dot-separated labels made of lowercase
+	// alphanumerics and hyphens, where no label starts or ends with a hyphen.
+	workerHostnamePattern = regexp.MustCompile(`^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`)
+)
+
+// middleManagerDrainState is the in-memory record of the drain rollout
+// currently tracked for one MiddleManager StatefulSet.
+type middleManagerDrainState struct {
+	// Phase is one of the middleManagerDrainPhase* constants.
+	Phase string
+	// PodName is the pod being drained or replaced.
+	PodName string
+	// PodOrdinal is the StatefulSet ordinal of PodName; the StatefulSet
+	// partition is lowered to this value to let the pod roll.
+	PodOrdinal int32
+	// OldPodUID is the UID the pod had before it was rolled; empty when it
+	// could not be read.
+	OldPodUID string
+	// LastUpdateTime is when the current phase was entered; phase timeouts
+	// are measured from it.
+	LastUpdateTime time.Time
+}
+
+// middleManagerDrainConfig holds the effective timeouts produced by
+// normalizeMiddleManagerDrainConfig.
+type middleManagerDrainConfig struct {
+	// DrainTimeout bounds the Draining phase.
+	DrainTimeout time.Duration
+	// PodReadyTimeout bounds the WaitingForPod phase.
+	PodReadyTimeout time.Duration
+}
+
+// middleManagerDruidAPI is the set of Druid API operations used by the drain
+// rollout.
+type middleManagerDruidAPI interface {
+	// DisableWorker calls the worker "disable" endpoint for workerHost.
+	DisableWorker(workerHost string) error
+	// EnableWorker calls the worker "enable" endpoint for workerHost.
+	EnableWorker(workerHost string) error
+	// GetTaskPayload fetches and decodes the payload of the given task.
+	GetTaskPayload(taskID string) (*taskPayloadResponse, error)
+	// TriggerTaskGroupHandoff requests handoff of the given task groups of a supervisor.
+	TriggerTaskGroupHandoff(supervisorID string, taskGroupIDs []int) error
+	// ExecuteSQL runs a Druid SQL query and returns the raw response body.
+	ExecuteSQL(query string) ([]byte, error)
+}
+
+// middleManagerDruidHTTPAPI implements middleManagerDruidAPI over HTTP.
+type middleManagerDruidHTTPAPI struct {
+	// baseURL is the base URL every request path is built from.
+	baseURL string
+	// httpClient performs the requests.
+	httpClient internalhttp.DruidHTTP
+}
+
+// druidSQLRequest is the JSON body sent by ExecuteSQL.
+type druidSQLRequest struct {
+	Query string `json:"query"`
+}
+
+// runningTaskInfo describes one task, decoded from JSON with the keys
+// task_id, datasource and type.
+// NOTE(review): presumably one row of a Druid SQL task query result; the
+// code that produces it is not visible in this part of the file — confirm.
+type runningTaskInfo struct {
+	TaskID     string `json:"task_id"`
+	DataSource string `json:"datasource"`
+	Type       string `json:"type"`
+}
+
+// taskPayloadResponse is the subset of the task payload API response that
+// GetTaskPayload decodes.
+type taskPayloadResponse struct {
+	Task    string `json:"task"`
+	Payload struct {
+		DataSource string `json:"dataSource"`
+		IOConfig   struct {
+			// TaskGroupID stays nil when the response carries no taskGroupId.
+			TaskGroupID *int `json:"taskGroupId"`
+		} `json:"ioConfig"`
+	} `json:"payload"`
+}
+
+// taskGroupHandoffRequest is the JSON body sent by TriggerTaskGroupHandoff.
+type taskGroupHandoffRequest struct {
+	TaskGroupIDs []int `json:"taskGroupIds"`
+}
+
+// newMiddleManagerDruidAPI builds the Druid API client used by the drain
+// rollout for the given Druid cluster.
+//
+// It looks up the router service URL for drd and the credentials referenced
+// by drd.Spec.Auth, and returns an HTTP-backed client using both. An error is
+// returned when either lookup fails.
+//
+// The underlying http.Client carries a request timeout: these calls run
+// inside the reconcile loop, and a client without a timeout can block
+// forever on an unresponsive endpoint.
+func newMiddleManagerDruidAPI(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid) (middleManagerDruidAPI, error) {
+	// Upper bound for a single Druid API round trip.
+	const requestTimeout = 30 * time.Second
+
+	routerURL, err := druidapi.GetRouterSvcUrl(drd.Namespace, drd.Name, sdk)
+	if err != nil {
+		return nil, fmt.Errorf("failed to discover Druid router service: %w", err)
+	}
+
+	basicAuth, err := druidapi.GetAuthCreds(ctx, sdk, drd.Spec.Auth)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get Druid API credentials: %w", err)
+	}
+
+	return &middleManagerDruidHTTPAPI{
+		baseURL: routerURL,
+		httpClient: internalhttp.NewHTTPClient(
+			&http.Client{Timeout: requestTimeout},
+			&internalhttp.Auth{BasicAuth: basicAuth},
+		),
+	}, nil
+}
+
+// DisableWorker POSTs to the worker "disable" endpoint for workerHost. Any
+// transport error or a status other than 200 OK is returned as an error.
+func (c *middleManagerDruidHTTPAPI) DisableWorker(workerHost string) error {
+	workerAPI := druidapi.MakePath(c.baseURL, "indexer", "worker")
+	endpoint := fmt.Sprintf("%s/%s/disable", workerAPI, url.PathEscape(workerHost))
+
+	resp, err := c.httpClient.Do(http.MethodPost, endpoint, nil)
+	switch {
+	case err != nil:
+		return fmt.Errorf("failed to call disable API for worker %q: %w", workerHost, err)
+	case resp.StatusCode != http.StatusOK:
+		return fmt.Errorf("disable API returned status %d for worker %q: %s", resp.StatusCode, workerHost, resp.ResponseBody)
+	}
+	return nil
+}
+
+// EnableWorker POSTs to the worker "enable" endpoint for workerHost. Any
+// transport error or a status other than 200 OK is returned as an error.
+func (c *middleManagerDruidHTTPAPI) EnableWorker(workerHost string) error {
+	workerAPI := druidapi.MakePath(c.baseURL, "indexer", "worker")
+	endpoint := fmt.Sprintf("%s/%s/enable", workerAPI, url.PathEscape(workerHost))
+
+	resp, err := c.httpClient.Do(http.MethodPost, endpoint, nil)
+	switch {
+	case err != nil:
+		return fmt.Errorf("failed to call enable API for worker %q: %w", workerHost, err)
+	case resp.StatusCode != http.StatusOK:
+		return fmt.Errorf("enable API returned status %d for worker %q: %s", resp.StatusCode, workerHost, resp.ResponseBody)
+	}
+	return nil
+}
+
+// GetTaskPayload GETs the task endpoint for taskID and decodes the JSON
+// response body. A transport error, a status other than 200 OK, or an
+// undecodable body is returned as an error.
+func (c *middleManagerDruidHTTPAPI) GetTaskPayload(taskID string) (*taskPayloadResponse, error) {
+	taskAPI := druidapi.MakePath(c.baseURL, "indexer", "task")
+	endpoint := fmt.Sprintf("%s/%s", taskAPI, url.PathEscape(taskID))
+
+	resp, err := c.httpClient.Do(http.MethodGet, endpoint, nil)
+	switch {
+	case err != nil:
+		return nil, fmt.Errorf("failed to fetch task payload for %q: %w", taskID, err)
+	case resp.StatusCode != http.StatusOK:
+		return nil, fmt.Errorf("task payload API returned status %d for %q: %s", resp.StatusCode, taskID, resp.ResponseBody)
+	}
+
+	decoded := new(taskPayloadResponse)
+	if err := json.Unmarshal([]byte(resp.ResponseBody), decoded); err != nil {
+		return nil, fmt.Errorf("failed to decode task payload for %q: %w", taskID, err)
+	}
+	return decoded, nil
+}
+
+// TriggerTaskGroupHandoff POSTs the given task group IDs to the supervisor's
+// taskGroups/handoff endpoint. Both 200 OK and 202 Accepted count as success.
+func (c *middleManagerDruidHTTPAPI) TriggerTaskGroupHandoff(supervisorID string, taskGroupIDs []int) error {
+	supervisorAPI := druidapi.MakePath(c.baseURL, "indexer", "supervisor")
+	endpoint := fmt.Sprintf("%s/%s/taskGroups/handoff", supervisorAPI, url.PathEscape(supervisorID))
+
+	body, err := json.Marshal(taskGroupHandoffRequest{TaskGroupIDs: taskGroupIDs})
+	if err != nil {
+		return fmt.Errorf("failed to marshal handoff request for %q: %w", supervisorID, err)
+	}
+
+	resp, err := c.httpClient.Do(http.MethodPost, endpoint, body)
+	if err != nil {
+		return fmt.Errorf("failed to trigger handoff for supervisor %q: %w", supervisorID, err)
+	}
+
+	switch resp.StatusCode {
+	case http.StatusOK, http.StatusAccepted:
+		return nil
+	default:
+		return fmt.Errorf("handoff API returned status %d for supervisor %q: %s", resp.StatusCode, supervisorID, resp.ResponseBody)
+	}
+}
+
+// ExecuteSQL POSTs query to the Druid SQL endpoint and returns the raw
+// response body. A transport error or a status other than 200 OK is returned
+// as an error.
+func (c *middleManagerDruidHTTPAPI) ExecuteSQL(query string) ([]byte, error) {
+	body, err := json.Marshal(druidSQLRequest{Query: query})
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal SQL request: %w", err)
+	}
+
+	endpoint := druidapi.MakeSQLPath(c.baseURL)
+	resp, err := c.httpClient.Do(http.MethodPost, endpoint, body)
+	switch {
+	case err != nil:
+		return nil, fmt.Errorf("failed to execute Druid SQL: %w", err)
+	case resp.StatusCode != http.StatusOK:
+		return nil, fmt.Errorf("Druid SQL API returned status %d: %s", resp.StatusCode, resp.ResponseBody)
+	}
+	return []byte(resp.ResponseBody), nil
+}
+
+// normalizeMiddleManagerDrainConfig converts the CR strategy into effective
+// timeouts. A nil strategy, or a timeout that is zero or negative, falls back
+// to the corresponding default.
+func normalizeMiddleManagerDrainConfig(strategy *v1alpha1.MiddleManagerDrainStrategy) middleManagerDrainConfig {
+	drainTimeout := defaultMiddleManagerDrainTimeout
+	podReadyTimeout := defaultMiddleManagerPodReadyTimeout
+
+	if strategy != nil {
+		if d := strategy.DrainTimeout.Duration; d > 0 {
+			drainTimeout = d
+		}
+		if d := strategy.PodReadyTimeout.Duration; d > 0 {
+			podReadyTimeout = d
+		}
+	}
+
+	return middleManagerDrainConfig{
+		DrainTimeout:    drainTimeout,
+		PodReadyTimeout: podReadyTimeout,
+	}
+}
+
+// middleManagerDrainStateKey returns the cache key
+// "namespace/druidCR/statefulSetName" for a MiddleManager StatefulSet.
+func middleManagerDrainStateKey(namespace, druidCR, statefulSetName string) string {
+	return strings.Join([]string{namespace, druidCR, statefulSetName}, "/")
+}
+
+// getMiddleManagerDrainState returns the cached drain state for the given
+// StatefulSet, or (nil, false) when nothing is cached.
+func getMiddleManagerDrainState(namespace, druidCR, statefulSetName string) (*middleManagerDrainState, bool) {
+	cached, found := middleManagerDrainStates.Load(middleManagerDrainStateKey(namespace, druidCR, statefulSetName))
+	if !found {
+		return nil, false
+	}
+	return cached.(*middleManagerDrainState), true
+}
+
+// loadMiddleManagerDrainState returns the drain state for statefulSetName.
+// The in-memory cache is consulted first. On a miss, the state is rebuilt from
+// drd.Status.MiddleManagerDrain, provided it refers to this StatefulSet and
+// has a phase and pod name, and the rebuilt state is stored in the cache.
+// A persisted status without a transition time is stamped with time.Now().
+func loadMiddleManagerDrainState(drd *v1alpha1.Druid, statefulSetName string) (*middleManagerDrainState, bool) {
+	if cached, ok := getMiddleManagerDrainState(drd.Namespace, drd.Name, statefulSetName); ok {
+		return cached, true
+	}
+
+	persisted := drd.Status.MiddleManagerDrain
+	if persisted == nil {
+		return nil, false
+	}
+	if persisted.StatefulSet != statefulSetName || persisted.Phase == "" || persisted.PodName == "" {
+		return nil, false
+	}
+
+	lastUpdate := persisted.LastTransitionTime.Time
+	if lastUpdate.IsZero() {
+		lastUpdate = time.Now()
+	}
+
+	restored := &middleManagerDrainState{
+		Phase:          persisted.Phase,
+		PodName:        persisted.PodName,
+		PodOrdinal:     persisted.PodOrdinal,
+		OldPodUID:      persisted.OldPodUID,
+		LastUpdateTime: lastUpdate,
+	}
+	cacheKey := middleManagerDrainStateKey(drd.Namespace, drd.Name, statefulSetName)
+	middleManagerDrainStates.Store(cacheKey, restored)
+	return restored, true
+}
+
+// setMiddleManagerDrainState records state in the in-memory cache and then
+// patches the same information, plus message, into the Druid CR status.
+// A zero state.LastUpdateTime is replaced with time.Now() before storing.
+// The cache is updated even when the status patch fails; the patch error is
+// returned.
+func setMiddleManagerDrainState(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, state *middleManagerDrainState, message string, emitEvent EventEmitter) error {
+	if state.LastUpdateTime.IsZero() {
+		state.LastUpdateTime = time.Now()
+	}
+
+	cacheKey := middleManagerDrainStateKey(drd.Namespace, drd.Name, statefulSetName)
+	middleManagerDrainStates.Store(cacheKey, state)
+
+	status := v1alpha1.MiddleManagerDrainStatus{
+		StatefulSet:        statefulSetName,
+		Phase:              state.Phase,
+		PodName:            state.PodName,
+		PodOrdinal:         state.PodOrdinal,
+		OldPodUID:          state.OldPodUID,
+		LastTransitionTime: metav1.NewTime(state.LastUpdateTime),
+		Message:            message,
+	}
+	return patchMiddleManagerDrainStatus(ctx, sdk, drd, &status, emitEvent)
+}
+
+// clearMiddleManagerDrainState drops the cached drain state for
+// statefulSetName and, when the CR status currently reports a drain for that
+// same StatefulSet, patches the status field back to nil.
+func clearMiddleManagerDrainState(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, emitEvent EventEmitter) error {
+	middleManagerDrainStates.Delete(middleManagerDrainStateKey(drd.Namespace, drd.Name, statefulSetName))
+
+	persisted := drd.Status.MiddleManagerDrain
+	if persisted != nil && persisted.StatefulSet == statefulSetName {
+		return patchMiddleManagerDrainStatus(ctx, sdk, drd, nil, emitEvent)
+	}
+	return nil
+}
+
+// patchMiddleManagerDrainStatus patches the CR status with MiddleManagerDrain
+// set to status (nil clears it). drd.Status is updated in memory only after
+// the patch succeeds.
+func patchMiddleManagerDrainStatus(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, status *v1alpha1.MiddleManagerDrainStatus, emitEvent EventEmitter) error {
+	next := drd.Status
+	next.MiddleManagerDrain = status
+
+	err := druidClusterStatusPatcher(ctx, sdk, next, drd, emitEvent)
+	if err == nil {
+		drd.Status = next
+	}
+	return err
+}
+
+// updateStatefulSetPartition sets the rolling-update partition of the named
+// StatefulSet. When the StatefulSet has no RollingUpdate settings, the update
+// strategy type is switched to RollingUpdate first. No update is sent when the
+// partition already holds the requested value.
+func updateStatefulSetPartition(ctx context.Context, sdk client.Client, statefulSetName, namespace string, partition int32) error {
+	var sts appsv1.StatefulSet
+	stsKey := types.NamespacedName{Name: statefulSetName, Namespace: namespace}
+	if err := sdk.Get(ctx, stsKey, &sts); err != nil {
+		return fmt.Errorf("failed to get StatefulSet %q in namespace %q: %w", statefulSetName, namespace, err)
+	}
+
+	updateStrategy := &sts.Spec.UpdateStrategy
+	if updateStrategy.RollingUpdate == nil {
+		updateStrategy.Type = appsv1.RollingUpdateStatefulSetStrategyType
+		updateStrategy.RollingUpdate = &appsv1.RollingUpdateStatefulSetStrategy{}
+	}
+
+	if current := updateStrategy.RollingUpdate.Partition; current != nil && *current == partition {
+		return nil
+	}
+
+	updateStrategy.RollingUpdate.Partition = &partition
+	return sdk.Update(ctx, &sts)
+}
+
+// middleManagerDrainStatefulSetUpdaterFn sets the rolling-update partition of
+// curr to its replica count (1 when Replicas is nil), switching the update
+// strategy to RollingUpdate when no RollingUpdate settings are present.
+// curr values that are not a *appsv1.StatefulSet are left untouched; prev is
+// not used.
+func middleManagerDrainStatefulSetUpdaterFn(prev, curr object) {
+	desired, isStatefulSet := curr.(*appsv1.StatefulSet)
+	if !isStatefulSet {
+		return
+	}
+
+	partition := int32(1)
+	if r := desired.Spec.Replicas; r != nil {
+		partition = *r
+	}
+
+	updateStrategy := &desired.Spec.UpdateStrategy
+	if updateStrategy.RollingUpdate == nil {
+		updateStrategy.Type = appsv1.RollingUpdateStatefulSetStrategyType
+		updateStrategy.RollingUpdate = &appsv1.RollingUpdateStatefulSetStrategy{}
+	}
+	updateStrategy.RollingUpdate.Partition = &partition
+}
+
+// cleanupStaleMiddleManagerDrainState undoes a drain rollout that is still
+// recorded for statefulSetName. It does nothing when no drain state exists.
+//
+// Otherwise it, in order: tries to re-enable the recorded worker, clears the
+// drain state (cache and CR status), and resets the StatefulSet partition.
+// Every step is best-effort: failures are logged and never returned, and a
+// failed step does not prevent the following ones from running.
+func cleanupStaleMiddleManagerDrainState(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, emitEvent EventEmitter) {
+	state, hasState := loadMiddleManagerDrainState(drd, statefulSetName)
+	if !hasState {
+		return
+	}
+
+	// Re-enable the worker that the abandoned rollout may have disabled.
+	if api, err := newMiddleManagerDruidAPI(ctx, sdk, drd); err != nil {
+		logger.Error(err, "Failed to create Druid API client while cleaning stale MiddleManager drain state", "statefulSet", statefulSetName)
+	} else if druidPort, err := getDruidPortFromStatefulSet(ctx, sdk, statefulSetName, drd.Namespace); err != nil {
+		logger.Error(err, "Failed to get Druid port while cleaning stale MiddleManager drain state", "statefulSet", statefulSetName)
+	} else if err := api.EnableWorker(buildMiddleManagerWorkerHost(state.PodName, statefulSetName, drd.Namespace, druidPort)); err != nil {
+		logger.Error(err, "Failed to re-enable MiddleManager while cleaning stale drain state", "pod", state.PodName, "statefulSet", statefulSetName)
+	}
+
+	if err := clearMiddleManagerDrainState(ctx, sdk, drd, statefulSetName, emitEvent); err != nil {
+		logger.Error(err, "Failed to clear stale MiddleManager drain status", "statefulSet", statefulSetName)
+	}
+	if err := updateStatefulSetPartition(ctx, sdk, statefulSetName, drd.Namespace, statefulSetPartitionResetValue); err != nil {
+		logger.Error(err, "Failed to reset StatefulSet partition while cleaning stale MiddleManager drain state", "statefulSet", statefulSetName)
+	}
+}
+
+// processMiddleManagerRollingRestart advances the drain rollout of one
+// MiddleManager StatefulSet by a single step.
+//
+// When the StatefulSet's current and update revisions already match, any
+// drain state is cleared and the partition is reset. Otherwise, if no drain is
+// being tracked, the partition is first raised to the replica count and a new
+// drain cycle is started; if a drain is being tracked, that cycle is continued.
+func processMiddleManagerRollingRestart(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, statefulSetName string, strategy *v1alpha1.MiddleManagerDrainStrategy, emitEvent EventEmitter) error {
+	namespace := drd.Namespace
+
+	var sts appsv1.StatefulSet
+	stsKey := types.NamespacedName{Name: statefulSetName, Namespace: namespace}
+	if err := sdk.Get(ctx, stsKey, &sts); err != nil {
+		return fmt.Errorf("failed to get StatefulSet %q in namespace %q: %w", statefulSetName, drd.Namespace, err)
+	}
+
+	// Revisions match: no rollout is pending, so release everything.
+	if sts.Status.UpdateRevision == sts.Status.CurrentRevision {
+		if err := clearMiddleManagerDrainState(ctx, sdk, drd, statefulSetName, emitEvent); err != nil {
+			return err
+		}
+		return updateStatefulSetPartition(ctx, sdk, statefulSetName, namespace, statefulSetPartitionResetValue)
+	}
+
+	config := normalizeMiddleManagerDrainConfig(strategy)
+
+	state, hasState := loadMiddleManagerDrainState(drd, statefulSetName)
+	if !hasState {
+		// Block the StatefulSet controller from rolling any pod before a
+		// drain target has been picked.
+		replicas := int32(1)
+		if sts.Spec.Replicas != nil {
+			replicas = *sts.Spec.Replicas
+		}
+		if err := updateStatefulSetPartition(ctx, sdk, statefulSetName, namespace, replicas); err != nil {
+			return fmt.Errorf("failed to block MiddleManager StatefulSet rolling update: %w", err)
+		}
+	}
+
+	api, err := newMiddleManagerDruidAPI(ctx, sdk, drd)
+	if err != nil {
+		return err
+	}
+
+	druidPort, err := getDruidPortFromStatefulSet(ctx, sdk, statefulSetName, namespace)
+	if err != nil {
+		return err
+	}
+
+	if !hasState {
+		return startMiddleManagerDrainCycle(ctx, sdk, drd, &sts, api, druidPort, emitEvent)
+	}
+	return continueMiddleManagerDrainCycle(ctx, sdk, drd, &sts, api, state, druidPort, config, emitEvent)
+}
+
+// startMiddleManagerDrainCycle picks the first outdated MiddleManager pod
+// after sortPodsDescending, starts draining it, and records the Draining
+// phase for it.
+//
+// It returns nil without recording state when there are no outdated pods, or
+// when the chosen pod already carries the update revision (in which case the
+// worker is re-enabled on a best-effort basis).
+func startMiddleManagerDrainCycle(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, sts *appsv1.StatefulSet, api middleManagerDruidAPI, druidPort int32, emitEvent EventEmitter) error {
+	outdated, err := getOutdatedMiddleManagerPods(ctx, sdk, sts.Name, sts.Namespace, sts.Status.CurrentRevision)
+	switch {
+	case err != nil:
+		return err
+	case len(outdated) == 0:
+		return nil
+	}
+
+	sortPodsDescending(outdated)
+	target := outdated[0]
+
+	ordinal := extractPodOrdinal(target.Name)
+	if ordinal < 0 {
+		return fmt.Errorf("could not extract ordinal from MiddleManager pod name %q", target.Name)
+	}
+
+	workerHost := buildMiddleManagerWorkerHost(target.Name, sts.Name, sts.Namespace, druidPort)
+
+	// Pod is already on the update revision: make sure it accepts work again
+	// instead of draining it.
+	if target.Labels["controller-revision-hash"] == sts.Status.UpdateRevision {
+		if err := api.EnableWorker(workerHost); err != nil {
+			logger.Error(err, "Failed to enable already-updated MiddleManager pod", "pod", target.Name)
+		}
+		return nil
+	}
+
+	if err := drainMiddleManager(api, workerHost); err != nil {
+		return err
+	}
+
+	draining := &middleManagerDrainState{
+		Phase:      middleManagerDrainPhaseDraining,
+		PodName:    target.Name,
+		PodOrdinal: ordinal,
+	}
+	return setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, draining, "Drain initiated; waiting for streaming ingestion tasks to finish", emitEvent)
+}
+
+// continueMiddleManagerDrainCycle advances an already-tracked drain by one
+// step, based on state.Phase:
+//
+//   - Draining: while DrainTimeout has not elapsed, keep waiting until the
+//     worker reports drained. Once drained (or once the timeout has elapsed),
+//     remember the pod's UID, lower the StatefulSet partition to the pod's
+//     ordinal so Kubernetes replaces it, and move to WaitingForPod. The status
+//     message states whether the drain completed or timed out.
+//   - WaitingForPod: after PodReadyTimeout, move to Blocked and return an
+//     error. Otherwise wait until the pod exists with a new UID and is ready;
+//     then re-enable the worker and clear the drain state. A ready pod that is
+//     not on the update revision only clears the drain state.
+//   - Blocked: return an error carrying the recorded status message.
+//   - any other phase: clear the drain state.
+//
+// A nil return with no state change means "not ready yet, try again on the
+// next reconcile". Kubernetes API errors other than NotFound are returned so
+// the reconcile is retried instead of acting on incomplete information.
+func continueMiddleManagerDrainCycle(ctx context.Context, sdk client.Client, drd *v1alpha1.Druid, sts *appsv1.StatefulSet, api middleManagerDruidAPI, state *middleManagerDrainState, druidPort int32, config middleManagerDrainConfig, emitEvent EventEmitter) error {
+	workerHost := buildMiddleManagerWorkerHost(state.PodName, sts.Name, sts.Namespace, druidPort)
+	elapsed := time.Since(state.LastUpdateTime)
+	podKey := types.NamespacedName{Name: state.PodName, Namespace: sts.Namespace}
+
+	switch state.Phase {
+	case middleManagerDrainPhaseDraining:
+		message := "Drain complete; waiting for replacement pod to become ready"
+		if elapsed < config.DrainTimeout {
+			drained, err := isMiddleManagerDrained(api, workerHost)
+			if err != nil {
+				logger.Error(err, "Failed to check MiddleManager drain status; will retry", "pod", state.PodName)
+				return nil
+			}
+			if !drained {
+				return setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, state, "Waiting for streaming ingestion tasks to drain", emitEvent)
+			}
+		} else {
+			// The pod is rolled anyway; do not claim the drain finished.
+			message = fmt.Sprintf("Drain timed out after %s; rolling pod and waiting for replacement to become ready", config.DrainTimeout)
+		}
+
+		// The old UID is what later tells the replacement apart from the pod
+		// being rolled. A missing pod is fine (UID stays empty); any other
+		// error must not be ignored, or the old pod could be mistaken for its
+		// replacement.
+		oldPodUID := ""
+		var oldPod v1.Pod
+		if err := sdk.Get(ctx, podKey, &oldPod); err == nil {
+			oldPodUID = string(oldPod.UID)
+		} else if client.IgnoreNotFound(err) != nil {
+			return fmt.Errorf("failed to get MiddleManager pod %q before rolling it: %w", state.PodName, err)
+		}
+
+		if err := updateStatefulSetPartition(ctx, sdk, sts.Name, sts.Namespace, state.PodOrdinal); err != nil {
+			return fmt.Errorf("failed to lower StatefulSet partition for pod %q: %w", state.PodName, err)
+		}
+		// LastUpdateTime is left zero so the new phase starts its own timer.
+		return setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, &middleManagerDrainState{
+			Phase:      middleManagerDrainPhaseWaitingForPod,
+			PodName:    state.PodName,
+			PodOrdinal: state.PodOrdinal,
+			OldPodUID:  oldPodUID,
+		}, message, emitEvent)
+
+	case middleManagerDrainPhaseWaitingForPod:
+		if elapsed >= config.PodReadyTimeout {
+			blockedState := *state
+			blockedState.Phase = middleManagerDrainPhaseBlocked
+			blockedState.LastUpdateTime = time.Time{}
+			message := fmt.Sprintf("Timed out after %s waiting for replacement pod to become ready", config.PodReadyTimeout)
+			if err := setMiddleManagerDrainState(ctx, sdk, drd, sts.Name, &blockedState, message, emitEvent); err != nil {
+				return err
+			}
+			return fmt.Errorf("MiddleManager pod %q did not become ready before timeout %s", state.PodName, config.PodReadyTimeout)
+		}
+
+		var pod v1.Pod
+		if err := sdk.Get(ctx, podKey, &pod); err != nil {
+			// NotFound is expected between deletion and re-creation of the
+			// pod; anything else is a real failure worth surfacing.
+			if client.IgnoreNotFound(err) != nil {
+				return fmt.Errorf("failed to get replacement MiddleManager pod %q: %w", state.PodName, err)
+			}
+			return nil
+		}
+		// Still the old pod: Kubernetes has not replaced it yet.
+		if state.OldPodUID != "" && string(pod.UID) == state.OldPodUID {
+			return nil
+		}
+		if !isPodReady(pod) {
+			return nil
+		}
+		// Ready but not on the update revision: drop the state so the next
+		// reconcile starts a fresh cycle.
+		if pod.Labels["controller-revision-hash"] != sts.Status.UpdateRevision {
+			return clearMiddleManagerDrainState(ctx, sdk, drd, sts.Name, emitEvent)
+		}
+		if err := api.EnableWorker(workerHost); err != nil {
+			logger.Error(err, "Failed to re-enable MiddleManager pod; will retry", "pod", state.PodName)
+			return nil
+		}
+		return clearMiddleManagerDrainState(ctx, sdk, drd, sts.Name, emitEvent)
+
+	case middleManagerDrainPhaseBlocked:
+		message := ""
+		if drd.Status.MiddleManagerDrain != nil {
+			message = drd.Status.MiddleManagerDrain.Message
+		}
+		return fmt.Errorf("MiddleManager drain rollout is blocked for pod %q: %s", state.PodName, message)
+
+	default:
+		return clearMiddleManagerDrainState(ctx, sdk, drd, sts.Name, emitEvent)
+	}
+}
+
+// drainMiddleManager starts draining the worker at workerHost: it disables the
+// worker, looks up the tasks still running on it and triggers a task-group
+// handoff for every supervisor that owns at least one of those tasks.
+//
+// It returns an error if disabling the worker, listing its tasks, resolving
+// the handoffs or triggering any handoff fails. It only initiates the drain
+// and does not wait for the tasks to finish.
+func drainMiddleManager(api middleManagerDruidAPI, workerHost string) error {
+	if disableErr := api.DisableWorker(workerHost); disableErr != nil {
+		return fmt.Errorf("failed to disable MiddleManager: %w", disableErr)
+	}
+
+	taskIDs, listErr := getRunningTasksFromSQL(api, workerHost)
+	switch {
+	case listErr != nil:
+		return fmt.Errorf("failed to get running tasks after disabling MiddleManager: %w", listErr)
+	case len(taskIDs) == 0:
+		// Nothing is running on the worker, so there is nothing to hand off.
+		return nil
+	}
+
+	groupsBySupervisor, resolveErr := resolveHandoffsForWorker(api, taskIDs)
+	if resolveErr != nil {
+		return resolveErr
+	}
+	for supervisor, groups := range groupsBySupervisor {
+		if len(groups) > 0 {
+			if handoffErr := api.TriggerTaskGroupHandoff(supervisor, groups); handoffErr != nil {
+				return fmt.Errorf("failed to trigger handoff for supervisor %q: %w", supervisor, handoffErr)
+			}
+		}
+	}
+	return nil
+}
+
+// resolveHandoffsForWorker maps the given running task IDs to the task groups
+// that have to be handed off, keyed by the task payload's data source (used as
+// the supervisor ID).
+//
+// Each task payload is fetched through api. Tasks whose payload cannot be
+// fetched are logged and skipped; tasks with a missing payload, an empty data
+// source or no task group ID are skipped silently. Task group IDs are
+// de-duplicated per supervisor and returned in ascending order. The returned
+// error is currently always nil.
+func resolveHandoffsForWorker(api middleManagerDruidAPI, runningTaskIDs []string) (map[string][]int, error) {
+	supervisorToGroupIDs := map[string]map[int]bool{}
+	for _, taskID := range runningTaskIDs {
+		payload, err := api.GetTaskPayload(taskID)
+		if err != nil {
+			logger.Error(err, "Failed to fetch task payload while resolving MiddleManager handoff", "taskID", taskID)
+			continue
+		}
+		// An implementation may return a nil payload without an error (for
+		// example for an unknown task); dereferencing it would panic.
+		if payload == nil {
+			continue
+		}
+		if payload.Payload.DataSource == "" || payload.Payload.IOConfig.TaskGroupID == nil {
+			continue
+		}
+		if supervisorToGroupIDs[payload.Payload.DataSource] == nil {
+			supervisorToGroupIDs[payload.Payload.DataSource] = map[int]bool{}
+		}
+		supervisorToGroupIDs[payload.Payload.DataSource][*payload.Payload.IOConfig.TaskGroupID] = true
+	}
+
+	result := make(map[string][]int, len(supervisorToGroupIDs))
+	for supervisorID, groupIDSet := range supervisorToGroupIDs {
+		ids := make([]int, 0, len(groupIDSet))
+		for id := range groupIDSet {
+			ids = append(ids, id)
+		}
+		// Map iteration order is random; sort for deterministic requests.
+		sort.Ints(ids)
+		result[supervisorID] = ids
+	}
+	return result, nil
+}
+
+// getRunningTasksFromSQL returns the IDs of all tasks that sys.tasks reports
+// as RUNNING at a location starting with the worker's hostname.
+//
+// The port is stripped from workerHost and the hostname is validated before it
+// is interpolated into the SQL text. Rows without a task ID are dropped. An
+// error is returned when the hostname is rejected, the query fails or the
+// response cannot be decoded.
+func getRunningTasksFromSQL(api middleManagerDruidAPI, workerHost string) ([]string, error) {
+	host, err := validateWorkerHostnameForSQL(workerHost)
+	if err != nil {
+		return nil, err
+	}
+
+	response, err := api.ExecuteSQL(fmt.Sprintf(`SELECT "task_id", "datasource", "type" FROM sys.tasks WHERE "runner_status" = 'RUNNING' AND "location" LIKE '%s:%%'`, host))
+	if err != nil {
+		return nil, err
+	}
+
+	var tasks []runningTaskInfo
+	if decodeErr := json.Unmarshal(response, &tasks); decodeErr != nil {
+		return nil, fmt.Errorf("failed to decode running tasks SQL response: %w", decodeErr)
+	}
+
+	ids := make([]string, 0, len(tasks))
+	for i := range tasks {
+		if id := tasks[i].TaskID; id != "" {
+			ids = append(ids, id)
+		}
+	}
+	return ids, nil
+}
+
+// isMiddleManagerDrained reports whether the worker no longer has RUNNING
+// index_kafka or index_kinesis tasks according to sys.tasks.
+//
+// The port is stripped from workerHost and the hostname is validated before it
+// is interpolated into the SQL text. An empty result set counts as drained. An
+// error is returned when the hostname is rejected, the query fails or the
+// response cannot be decoded.
+func isMiddleManagerDrained(api middleManagerDruidAPI, workerHost string) (bool, error) {
+	host, err := validateWorkerHostnameForSQL(workerHost)
+	if err != nil {
+		return false, err
+	}
+
+	response, err := api.ExecuteSQL(fmt.Sprintf(`SELECT COUNT(*) AS "cnt" FROM sys.tasks WHERE "runner_status" = 'RUNNING' AND "location" LIKE '%s:%%' AND "type" IN ('index_kafka', 'index_kinesis')`, host))
+	if err != nil {
+		return false, err
+	}
+
+	var counts []struct {
+		Cnt int `json:"cnt"`
+	}
+	if decodeErr := json.Unmarshal(response, &counts); decodeErr != nil {
+		return false, fmt.Errorf("failed to decode drain check SQL response: %w", decodeErr)
+	}
+	if len(counts) > 0 && counts[0].Cnt != 0 {
+		return false, nil
+	}
+	return true, nil
+}
+
+// validateWorkerHostnameForSQL strips the port from workerHost and returns the
+// remaining hostname if it matches workerHostnamePattern. Any other hostname
+// is rejected with an error, so callers can embed the result in SQL text.
+func validateWorkerHostnameForSQL(workerHost string) (string, error) {
+	host := stripPort(workerHost)
+	if workerHostnamePattern.MatchString(host) {
+		return host, nil
+	}
+	return "", fmt.Errorf("invalid MiddleManager worker hostname %q", host)
+}
+
+// stripPort returns hostPort without its last ":"-separated suffix. Input that
+// contains no colon is returned unchanged.
+func stripPort(hostPort string) string {
+	if colon := strings.LastIndexByte(hostPort, ':'); colon >= 0 {
+		return hostPort[:colon]
+	}
+	return hostPort
+}
+
+// buildMiddleManagerWorkerHost returns the host:port of a MiddleManager pod in
+// the form <pod>.<service>.<namespace>.svc.cluster.local:<port>.
+func buildMiddleManagerWorkerHost(podName, serviceName, namespace string, port int32) string {
+	const workerHostFormat = "%s.%s.%s.svc.cluster.local:%d"
+	return fmt.Sprintf(workerHostFormat, podName, serviceName, namespace, port)
+}
+
+// getDruidPortFromStatefulSet fetches the named StatefulSet and returns the
+// container port of the first port named "druid-port" found in its pod
+// template. It returns an error if the StatefulSet cannot be fetched or if no
+// container declares such a port.
+func getDruidPortFromStatefulSet(ctx context.Context, sdk client.Client, statefulSetName, namespace string) (int32, error) {
+	key := types.NamespacedName{Name: statefulSetName, Namespace: namespace}
+	sts := appsv1.StatefulSet{}
+	if err := sdk.Get(ctx, key, &sts); err != nil {
+		return 0, err
+	}
+
+	containers := sts.Spec.Template.Spec.Containers
+	for c := range containers {
+		ports := containers[c].Ports
+		for p := range ports {
+			if ports[p].Name == "druid-port" {
+				return ports[p].ContainerPort, nil
+			}
+		}
+	}
+	return 0, fmt.Errorf("druid-port not found in StatefulSet %q pod template", statefulSetName)
+}
+
+// getOutdatedMiddleManagerPods lists the pods in namespace labelled
+// nodeSpecUniqueStr=<statefulSetName> and returns those whose
+// controller-revision-hash label equals currentRevision.
+//
+// NOTE(review): presumably callers pass the StatefulSet's pre-update
+// (current) revision, so that a match means "not yet rolled" - confirm
+// against the call site.
+func getOutdatedMiddleManagerPods(ctx context.Context, sdk client.Client, statefulSetName, namespace, currentRevision string) ([]v1.Pod, error) {
+	selector := client.MatchingLabels{
+		"nodeSpecUniqueStr": statefulSetName,
+	}
+	pods := v1.PodList{}
+	if err := sdk.List(ctx, &pods, client.InNamespace(namespace), selector); err != nil {
+		return nil, fmt.Errorf("failed to list MiddleManager pods: %w", err)
+	}
+
+	matching := make([]v1.Pod, 0, len(pods.Items))
+	for i := range pods.Items {
+		if pods.Items[i].Labels["controller-revision-hash"] != currentRevision {
+			continue
+		}
+		matching = append(matching, pods.Items[i])
+	}
+	return matching, nil
+}
+
+// sortPodsDescending stably sorts pods in place by the ordinal parsed from
+// each pod name, highest ordinal first.
+func sortPodsDescending(pods []v1.Pod) {
+	higherOrdinalFirst := func(a, b int) bool {
+		return extractPodOrdinal(pods[b].Name) < extractPodOrdinal(pods[a].Name)
+	}
+	sort.SliceStable(pods, higherOrdinalFirst)
+}
+
+// extractPodOrdinal returns the number formed by the trailing ASCII digits of
+// name (for example 3 for "druid-mm-3"). It returns -1 when name does not end
+// in a digit or when the number does not fit in an int32.
+//
+// The suffix is scanned by hand rather than with a regular expression so that
+// nothing is compiled per call; this function runs inside a sort comparator.
+func extractPodOrdinal(name string) int32 {
+	start := len(name)
+	for start > 0 && name[start-1] >= '0' && name[start-1] <= '9' {
+		start--
+	}
+	if start == len(name) {
+		return -1
+	}
+	// bitSize 32 makes out-of-range ordinals an error instead of letting an
+	// int32 conversion silently wrap them to an unrelated value.
+	ordinal, err := strconv.ParseInt(name[start:], 10, 32)
+	if err != nil {
+		return -1
+	}
+	return int32(ordinal)
+}
+
+// isPodReady reports whether pod is in the Running phase and carries a Ready
+// condition whose status is True.
+func isPodReady(pod v1.Pod) bool {
+	if pod.Status.Phase == v1.PodRunning {
+		conditions := pod.Status.Conditions
+		for i := range conditions {
+			if conditions[i].Type == v1.PodReady && conditions[i].Status == v1.ConditionTrue {
+				return true
+			}
+		}
+	}
+	return false
+}
diff --git a/controllers/druid/middle_manager_drain_test.go b/controllers/druid/middle_manager_drain_test.go
new file mode 100644
index 00000000..acd69cde
--- /dev/null
+++ b/controllers/druid/middle_manager_drain_test.go
@@ -0,0 +1,236 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+package druid
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/apache/druid-operator/apis/druid/v1alpha1"
+ internalhttp "github.com/apache/druid-operator/pkg/http"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ appsv1 "k8s.io/api/apps/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/client-go/kubernetes/scheme"
+ "k8s.io/client-go/tools/record"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+ "sigs.k8s.io/controller-runtime/pkg/client/fake"
+)
+
+// fakeMiddleManagerDruidAPI is an in-memory stand-in for the Druid API used by
+// the drain tests. It serves canned SQL responses and task payloads and
+// records every disable, enable and handoff call it receives.
+type fakeMiddleManagerDruidAPI struct {
+	sqlResponses map[string][]byte               // canned SQL bodies; only the "default" key is served
+	payloads     map[string]*taskPayloadResponse // task ID -> payload returned by GetTaskPayload
+	disabled     []string                        // worker hosts passed to DisableWorker, in call order
+	enabled      []string                        // worker hosts passed to EnableWorker, in call order
+	handoffs     map[string][]int                // supervisor ID -> task group IDs of the latest handoff
+}
+
+// DisableWorker records workerHost and always succeeds.
+func (f *fakeMiddleManagerDruidAPI) DisableWorker(workerHost string) error {
+	f.disabled = append(f.disabled, workerHost)
+	return nil
+}
+
+// EnableWorker records workerHost and always succeeds.
+func (f *fakeMiddleManagerDruidAPI) EnableWorker(workerHost string) error {
+	f.enabled = append(f.enabled, workerHost)
+	return nil
+}
+
+// GetTaskPayload returns the canned payload for taskID, or nil when none is
+// registered. It never returns an error.
+func (f *fakeMiddleManagerDruidAPI) GetTaskPayload(taskID string) (*taskPayloadResponse, error) {
+	payload := f.payloads[taskID]
+	return payload, nil
+}
+
+// TriggerTaskGroupHandoff records the task groups requested for supervisorID,
+// replacing any earlier record for the same supervisor.
+func (f *fakeMiddleManagerDruidAPI) TriggerTaskGroupHandoff(supervisorID string, taskGroupIDs []int) error {
+	recorded := f.handoffs
+	if recorded == nil {
+		recorded = map[string][]int{}
+		f.handoffs = recorded
+	}
+	recorded[supervisorID] = taskGroupIDs
+	return nil
+}
+
+// ExecuteSQL ignores the query text and returns the "default" canned response.
+func (f *fakeMiddleManagerDruidAPI) ExecuteSQL(_ string) ([]byte, error) {
+	body := f.sqlResponses["default"]
+	return body, nil
+}
+
+// taskPayload builds a task payload response whose data source is
+// supervisorID and whose IO config points at taskGroupID.
+func taskPayload(supervisorID string, taskGroupID int) *taskPayloadResponse {
+	var resp taskPayloadResponse
+	resp.Payload.IOConfig.TaskGroupID = &taskGroupID
+	resp.Payload.DataSource = supervisorID
+	return &resp
+}
+
+// TestMiddleManagerDruidHTTPAPIEscapesPathSegments checks that task and
+// supervisor IDs containing slashes are percent-escaped in the request URI
+// and that the worker host is sent as given.
+func TestMiddleManagerDruidHTTPAPIEscapesPathSegments(t *testing.T) {
+	// The stub server records each request URI and answers with a fixed task
+	// payload body.
+	seenURIs := []string{}
+	recordRequest := func(w http.ResponseWriter, r *http.Request) {
+		seenURIs = append(seenURIs, r.RequestURI)
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte(`{"task":"task/a","payload":{"dataSource":"wiki","ioConfig":{"taskGroupId":1}}}`))
+	}
+	server := httptest.NewServer(http.HandlerFunc(recordRequest))
+	defer server.Close()
+
+	httpClient := internalhttp.NewHTTPClient(server.Client(), &internalhttp.Auth{})
+	api := &middleManagerDruidHTTPAPI{
+		baseURL:    server.URL,
+		httpClient: httpClient,
+	}
+
+	disableErr := api.DisableWorker("mm-0.druid-mm.ns.svc.cluster.local:8091")
+	require.NoError(t, disableErr)
+	_, payloadErr := api.GetTaskPayload("task/with/slashes")
+	require.NoError(t, payloadErr)
+	handoffErr := api.TriggerTaskGroupHandoff("supervisor/with/slash", []int{1})
+	require.NoError(t, handoffErr)
+
+	require.Len(t, seenURIs, 3)
+	assert.Contains(t, seenURIs[0], "mm-0.druid-mm.ns.svc.cluster.local:8091")
+	assert.Contains(t, seenURIs[1], "task%2Fwith%2Fslashes")
+	assert.Contains(t, seenURIs[2], "supervisor%2Fwith%2Fslash")
+}
+
+// TestDrainMiddleManagerTriggersDeduplicatedHandoffs checks that draining
+// disables the worker once and requests each task group only once per
+// supervisor, even when several running tasks share a group.
+func TestDrainMiddleManagerTriggersDeduplicatedHandoffs(t *testing.T) {
+	const workerHost = "mm-0.druid-mm.druid.svc.cluster.local:8091"
+
+	// task-0 and task-duplicate share task group 1.
+	payloads := map[string]*taskPayloadResponse{
+		"task-0":         taskPayload("wiki", 1),
+		"task-1":         taskPayload("wiki", 2),
+		"task-duplicate": taskPayload("wiki", 1),
+	}
+	runningTasks := []byte(`[
+ {"task_id":"task-0","datasource":"wiki","type":"index_kafka"},
+ {"task_id":"task-1","datasource":"wiki","type":"index_kafka"},
+ {"task_id":"task-duplicate","datasource":"wiki","type":"index_kafka"}
+ ]`)
+	api := &fakeMiddleManagerDruidAPI{
+		sqlResponses: map[string][]byte{"default": runningTasks},
+		payloads:     payloads,
+	}
+
+	drainErr := drainMiddleManager(api, workerHost)
+	require.NoError(t, drainErr)
+
+	assert.Equal(t, []string{workerHost}, api.disabled)
+	assert.Equal(t, []int{1, 2}, api.handoffs["wiki"])
+}
+
+// TestValidateWorkerHostnameForSQLRejectsUnsafeHost checks that a regular
+// in-cluster worker host is accepted and that a host containing a single
+// quote is rejected.
+func TestValidateWorkerHostnameForSQLRejectsUnsafeHost(t *testing.T) {
+	_, safeErr := validateWorkerHostnameForSQL("mm-0.druid-mm.druid.svc.cluster.local:8091")
+	require.NoError(t, safeErr)
+
+	_, unsafeErr := validateWorkerHostnameForSQL("mm-0.bad'host.druid.svc.cluster.local:8091")
+	require.Error(t, unsafeErr)
+}
+
+// TestNormalizeMiddleManagerDrainConfig checks that a nil strategy yields the
+// default timeouts and that explicitly configured timeouts are kept.
+func TestNormalizeMiddleManagerDrainConfig(t *testing.T) {
+	defaults := middleManagerDrainConfig{
+		DrainTimeout:    defaultMiddleManagerDrainTimeout,
+		PodReadyTimeout: defaultMiddleManagerPodReadyTimeout,
+	}
+	assert.Equal(t, defaults, normalizeMiddleManagerDrainConfig(nil))
+
+	strategy := &v1alpha1.MiddleManagerDrainStrategy{
+		DrainTimeout:    metav1.Duration{Duration: 2 * time.Hour},
+		PodReadyTimeout: metav1.Duration{Duration: 10 * time.Minute},
+	}
+	normalized := normalizeMiddleManagerDrainConfig(strategy)
+	assert.Equal(t, 2*time.Hour, normalized.DrainTimeout)
+	assert.Equal(t, 10*time.Minute, normalized.PodReadyTimeout)
+}
+
+// TestMiddleManagerDrainStatefulSetUpdaterFnBlocksRollout checks that the
+// updater sets the rolling-update partition of a StatefulSet to its replica
+// count.
+func TestMiddleManagerDrainStatefulSetUpdaterFnBlocksRollout(t *testing.T) {
+	replicas := int32(3)
+	rolling := appsv1.StatefulSetUpdateStrategy{
+		Type: appsv1.RollingUpdateStatefulSetStrategyType,
+	}
+	sts := &appsv1.StatefulSet{
+		ObjectMeta: metav1.ObjectMeta{Name: "druid-mm"},
+		Spec: appsv1.StatefulSetSpec{
+			Replicas:       &replicas,
+			UpdateStrategy: rolling,
+		},
+	}
+
+	middleManagerDrainStatefulSetUpdaterFn(nil, sts)
+
+	rollingUpdate := sts.Spec.UpdateStrategy.RollingUpdate
+	require.NotNil(t, rollingUpdate)
+	require.NotNil(t, rollingUpdate.Partition)
+	assert.Equal(t, replicas, *rollingUpdate.Partition)
+}
+
+// TestContinueMiddleManagerDrainCycleBlocksOnPodReadyTimeout checks that a
+// drain cycle which has waited longer than PodReadyTimeout for the replacement
+// pod returns an error and records the Blocked phase both in the operator's
+// drain state and in the Druid status.
+func TestContinueMiddleManagerDrainCycleBlocksOnPodReadyTimeout(t *testing.T) {
+	require.NoError(t, v1alpha1.AddToScheme(scheme.Scheme))
+
+	const (
+		namespace = "druid"
+		druidName = "druid"
+		stsName   = "druid-mm"
+		podName   = "druid-mm-0"
+	)
+	ctx := context.Background()
+
+	drd := &v1alpha1.Druid{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      druidName,
+			Namespace: namespace,
+		},
+		Status: v1alpha1.DruidClusterStatus{
+			MiddleManagerDrain: &v1alpha1.MiddleManagerDrainStatus{
+				StatefulSet: stsName,
+				Phase:       middleManagerDrainPhaseWaitingForPod,
+				PodName:     podName,
+			},
+		},
+	}
+	sts := &appsv1.StatefulSet{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      stsName,
+			Namespace: namespace,
+		},
+		Status: appsv1.StatefulSetStatus{
+			CurrentRevision: "old",
+			UpdateRevision:  "new",
+		},
+	}
+
+	k8sClient := fake.NewClientBuilder().
+		WithScheme(scheme.Scheme).
+		WithObjects(drd, sts).
+		WithStatusSubresource(drd).
+		Build()
+
+	// The last update lies 31 minutes back, one minute past the 30 minute
+	// pod-ready timeout configured below.
+	waiting := &middleManagerDrainState{
+		Phase:          middleManagerDrainPhaseWaitingForPod,
+		PodName:        podName,
+		PodOrdinal:     0,
+		LastUpdateTime: time.Now().Add(-31 * time.Minute),
+	}
+	timeouts := middleManagerDrainConfig{DrainTimeout: time.Hour, PodReadyTimeout: 30 * time.Minute}
+
+	cycleErr := continueMiddleManagerDrainCycle(
+		ctx,
+		k8sClient,
+		drd,
+		sts,
+		&fakeMiddleManagerDruidAPI{},
+		waiting,
+		8091,
+		timeouts,
+		EmitEventFuncs{record.NewFakeRecorder(10)},
+	)
+	require.Error(t, cycleErr)
+
+	recorded, found := getMiddleManagerDrainState(namespace, druidName, stsName)
+	require.True(t, found)
+	assert.Equal(t, middleManagerDrainPhaseBlocked, recorded.Phase)
+
+	var persisted v1alpha1.Druid
+	require.NoError(t, k8sClient.Get(ctx, clientObjectKey(namespace, druidName), &persisted))
+	drainStatus := persisted.Status.MiddleManagerDrain
+	require.NotNil(t, drainStatus)
+	assert.Equal(t, middleManagerDrainPhaseBlocked, drainStatus.Phase)
+	assert.Contains(t, drainStatus.Message, "Timed out")
+}
+
+// clientObjectKey builds the object key for the given namespace and name.
+func clientObjectKey(namespace, name string) client.ObjectKey {
+	key := client.ObjectKey{Name: name, Namespace: namespace}
+	return key
+}
diff --git a/controllers/druid/status.go b/controllers/druid/status.go
index da940e84..357a1d24 100644
--- a/controllers/druid/status.go
+++ b/controllers/druid/status.go
@@ -67,7 +67,9 @@ func druidClusterStatusPatcher(ctx context.Context, sdk client.Client, updatedSt
if err != nil {
return fmt.Errorf("failed to serialize status patch to bytes: %v", err)
}
- _ = writers.Patch(ctx, sdk, m, m, true, client.RawPatch(types.MergePatchType, patchBytes), emitEvent)
+ if err := writers.Patch(ctx, sdk, m, m, true, client.RawPatch(types.MergePatchType, patchBytes), emitEvent); err != nil {
+ return err
+ }
}
return nil
}
diff --git a/controllers/druid/suite_test.go b/controllers/druid/suite_test.go
index 81d311b9..01883345 100644
--- a/controllers/druid/suite_test.go
+++ b/controllers/druid/suite_test.go
@@ -99,7 +99,8 @@ var _ = BeforeSuite(func() {
Expect(k8sClient).NotTo(BeNil())
k8sManager, err := ctrl.NewManager(cfg, ctrl.Options{
- Scheme: scheme.Scheme,
+ Scheme: scheme.Scheme,
+ MetricsBindAddress: "0",
})
Expect(err).ToNot(HaveOccurred())
diff --git a/docs/api_specifications/druid.md b/docs/api_specifications/druid.md
index 687e50f0..f0ad39c1 100644
--- a/docs/api_specifications/druid.md
+++ b/docs/api_specifications/druid.md
@@ -1,21 +1,3 @@
-
Druid API reference
Packages:
@@ -24,6 +6,20 @@ under the License.
druid.apache.org/v1alpha1
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+“License”); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+“AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
Resource Types:
AdditionalContainer
@@ -784,6 +780,22 @@ This will be done only for update actions.
+middleManagerDrainStrategy
+
+
+MiddleManagerDrainStrategy
+
+
+ |
+
+(Optional)
+ MiddleManagerDrainStrategy enables operator-managed draining before
+MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the
+standard StatefulSet rolling update behavior.
+ |
+
+
+
defaultProbes
bool
@@ -1066,6 +1078,18 @@ Important: Run “make” to regenerate code after modifying this file
|
|
+
+
+middleManagerDrain
+
+
+MiddleManagerDrainStatus
+
+
+ |
+
+ |
+
@@ -2614,6 +2638,22 @@ This will be done only for update actions.
+middleManagerDrainStrategy
+
+
+MiddleManagerDrainStrategy
+
+
+ |
+
+(Optional)
+ MiddleManagerDrainStrategy enables operator-managed draining before
+MiddleManager StatefulSet pods are rolled. If nil, MiddleManagers use the
+standard StatefulSet rolling update behavior.
+ |
+
+
+
defaultProbes
bool
@@ -2891,6 +2931,151 @@ encoding/json.RawMessage
+MiddleManagerDrainStatus
+
+
+(Appears on:
+DruidClusterStatus)
+
+MiddleManagerDrainStatus reports an in-progress MiddleManager drain rollout.
+
+MiddleManagerDrainStrategy
+
+
+(Appears on:
+DruidSpec)
+
+MiddleManagerDrainStrategy configures operator-managed draining before a
+MiddleManager StatefulSet pod is rolled to a new revision.
+
ZookeeperSpec
diff --git a/docs/druid_cr.md b/docs/druid_cr.md
index a0820625..3184e134 100644
--- a/docs/druid_cr.md
+++ b/docs/druid_cr.md
@@ -31,6 +31,11 @@ spec:
# more information in features.md and in druid documentation
# http://druid.io/docs/latest/operations/rolling-updates.html
rollingDeploy: true
+ # Optional: drain MiddleManager StatefulSet pods before replacing them during rolling updates.
+ # Omit this field to use standard StatefulSet rolling update behavior.
+ middleManagerDrainStrategy:
+ drainTimeout: 1h
+ podReadyTimeout: 30m
# Image for druid, Required Key
image: apache/druid:25.0.0
....
diff --git a/docs/features.md b/docs/features.md
index 6f8b74b3..2eab0339 100644
--- a/docs/features.md
+++ b/docs/features.md
@@ -23,6 +23,7 @@ under the License.
- [Finalizer in Druid CR](#finalizer-in-druid-cr)
- [Deletion of Orphan PVCs](#deletion-of-orphan-pvcs)
- [Rolling Deploy](#rolling-deploy)
+- [MiddleManager Drain Strategy](#middlemanager-drain-strategy)
- [Force Delete of Sts Pods](#force-delete-of-sts-pods)
- [Horizontal Scaling of Druid Pods](#horizontal-scaling-of-druid-pods)
- [Volume Expansion of Druid Pods Running As StatefulSets](#volume-expansion-of-druid-pods-running-as-statefulsets)
@@ -68,6 +69,39 @@ Default updates are done in parallel. Since cluster creation does not require a
in parallel anyway. To enable this feature, set `rollingDeploy: true` in the Druid CR.
⚠️ This feature is enabled by default.
+## MiddleManager Drain Strategy
+For streaming ingestion clusters, a regular StatefulSet rolling update can terminate MiddleManager pods while they still
+own running Kafka or Kinesis indexing tasks. The operator can instead drain each MiddleManager pod before Kubernetes
+replaces it.
+
+Enable this opt-in behavior by setting `spec.middleManagerDrainStrategy`. The field is an optional pointer: if it is
+omitted (nil), MiddleManagers use the standard StatefulSet rolling update behavior.
+
+```yaml
+spec:
+ rollingDeploy: true
+ middleManagerDrainStrategy:
+ drainTimeout: 1h
+ podReadyTimeout: 30m
+```
+
+When enabled, the operator:
+
+1. Blocks the MiddleManager StatefulSet rollout by setting the rolling update partition to the replica count.
+2. Selects one outdated MiddleManager pod at a time, highest ordinal first.
+3. Disables the worker through the Druid Overlord API.
+4. Finds running streaming tasks with Druid SQL and triggers supervisor task-group handoff.
+5. Waits for running Kafka/Kinesis tasks on that pod to drain, up to `drainTimeout`.
+6. Lowers the StatefulSet partition for that pod and waits for the replacement pod to be ready on the new revision.
+7. Re-enables the worker and proceeds to the next pod on the following reconcile.
+
+The strategy is supported only for MiddleManager nodes running as StatefulSets. If configured for a MiddleManager
+Deployment, the operator logs a warning and uses the normal Deployment rollout path.
+
+This feature requires the operator to reach the Druid Router service and use the Druid API credentials configured in
+`spec.auth`, when authentication is enabled. If the Router service cannot be discovered, the rollout is blocked and the
+error is surfaced by the reconcile loop.
+
## Force Delete of Sts Pods
During upgradeS, if THE StatefulSet is set to `OrderedReady` - the StatefulSet controller will not recover from
crash-loopback state. The issues is referenced [here](https://github.com/kubernetes/kubernetes/issues/67250).
diff --git a/pkg/druidapi/druidapi.go b/pkg/druidapi/druidapi.go
index 3cf4749b..028f87c9 100644
--- a/pkg/druidapi/druidapi.go
+++ b/pkg/druidapi/druidapi.go
@@ -141,6 +141,18 @@ func MakePath(baseURL, componentType, apiType string, additionalPaths ...string)
return u.String()
}
+// MakeSQLPath returns baseURL with its path replaced by Druid's SQL API path,
+// druid/v2/sql. Any path already present in baseURL is discarded. If baseURL
+// cannot be parsed, the error is printed to standard output and an empty
+// string is returned.
+func MakeSQLPath(baseURL string) string {
+	parsed, parseErr := url.Parse(baseURL)
+	if parseErr != nil {
+		fmt.Println("Error parsing URL:", parseErr)
+		return ""
+	}
+
+	sqlPath := path.Join("druid", "v2", "sql")
+	parsed.Path = sqlPath
+	return parsed.String()
+}
+
// GetRouterSvcUrl retrieves the URL of the Druid router service.
// Parameters:
//
diff --git a/pkg/druidapi/druidapi_test.go b/pkg/druidapi/druidapi_test.go
index ebcb6dbd..b7d80e6f 100644
--- a/pkg/druidapi/druidapi_test.go
+++ b/pkg/druidapi/druidapi_test.go
@@ -182,3 +182,36 @@ func TestMakePath(t *testing.T) {
})
}
}
+
+// TestMakeSQLPath checks the SQL endpoint URL built for a plain service URL,
+// for a base URL that already has a path, and for an empty base URL.
+func TestMakeSQLPath(t *testing.T) {
+	type testCase struct {
+		name     string
+		baseURL  string
+		expected string
+	}
+	cases := []testCase{
+		{name: "RouterService", baseURL: "http://example-druid-service", expected: "http://example-druid-service/druid/v2/sql"},
+		{name: "BaseURLWithPath", baseURL: "http://example-druid-service/base", expected: "http://example-druid-service/druid/v2/sql"},
+		{name: "EmptyBaseURL", baseURL: "", expected: "druid/v2/sql"},
+	}
+
+	for i := range cases {
+		tc := cases[i]
+		t.Run(tc.name, func(t *testing.T) {
+			if actual := MakeSQLPath(tc.baseURL); actual != tc.expected {
+				t.Errorf("MakeSQLPath() = %v, expected %v", actual, tc.expected)
+			}
+		})
+	}
+}
|