diff --git a/pkg/receive/config.go b/pkg/receive/config.go
index a81508f50f2..69448f99555 100644
--- a/pkg/receive/config.go
+++ b/pkg/receive/config.go
@@ -10,6 +10,8 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math"
 	"os"
 	"path/filepath"
+	"strconv"
 	"strings"
@@ -118,6 +120,105 @@ func (e *Endpoint) unmarshal(data []byte) error {
 	return nil
 }
 
+// ShardSize represents a shard size that can be either an absolute integer count
+// or a percentage of available shards. Percentages are only supported for the
+// rendezvous algorithm.
+type ShardSize struct {
+	Value     int     // absolute count (used when IsPercent=false)
+	Percent   float64 // 0.0-1.0 (used when IsPercent=true)
+	IsPercent bool
+}
+
+// IsZero reports whether no usable shard size is set: the active field
+// (Percent when IsPercent is true, Value otherwise) is zero or negative.
+// Negative values count as unset so that `!IsZero()` stays equivalent to the
+// `ShardSize > 0` checks it replaces.
+func (s ShardSize) IsZero() bool {
+	if s.IsPercent {
+		return s.Percent <= 0
+	}
+	return s.Value <= 0
+}
+
+// ResolveCount resolves the shard size to an absolute count given a total.
+// For percentages, returns max(1, floor(total * percent)).
+// For absolute values, returns the Value directly.
+func (s ShardSize) ResolveCount(total int) int {
+	if !s.IsPercent {
+		return s.Value
+	}
+	// Fractions such as 0.29 have no exact float64 form, so the product can land
+	// just below the intended integer (0.29*100 == 28.999999999999996). Nudge it
+	// up by a tiny epsilon before flooring so that 29% of 100 resolves to 29.
+	const epsilon = 1e-9
+	count := int(math.Floor(float64(total)*s.Percent + epsilon))
+	if count < 1 {
+		return 1
+	}
+	return count
+}
+
+// String returns a human-readable representation of the shard size.
+func (s ShardSize) String() string {
+	if s.IsPercent {
+		return s.percentString()
+	}
+	return strconv.Itoa(s.Value)
+}
+
+// percentString renders Percent as a percentage literal such as "50%" or
+// "12.5%", the same form UnmarshalJSON accepts.
+func (s ShardSize) percentString() string {
+	// Round to six decimal places to drop float64 noise (0.29*100 is
+	// 28.999999999999996) while keeping fractional percentages intact.
+	pct := math.Round(s.Percent*100*1e6) / 1e6
+	return strconv.FormatFloat(pct, 'f', -1, 64) + "%"
+}
+
+// UnmarshalJSON supports both integer (e.g. 6) and percentage string (e.g. "50%") formats.
+func (s *ShardSize) UnmarshalJSON(data []byte) error {
+	// Try integer first.
+	var intVal int
+	if err := json.Unmarshal(data, &intVal); err == nil {
+		*s = ShardSize{Value: intVal}
+		return nil
+	}
+
+	// Try string (percentage format).
+	var strVal string
+	if err := json.Unmarshal(data, &strVal); err != nil {
+		return fmt.Errorf("shard_size must be an integer or a percentage string (e.g. \"50%%\"), got: %s", string(data))
+	}
+
+	strVal = strings.TrimSpace(strVal)
+	if !strings.HasSuffix(strVal, "%") {
+		return fmt.Errorf("shard_size string must end with '%%', got: %q", strVal)
+	}
+
+	// ParseFloat must consume the whole number, so inputs with trailing garbage
+	// such as "50abc%" are rejected instead of being read as 50.
+	numStr := strings.TrimSpace(strings.TrimSuffix(strVal, "%"))
+	pct, err := strconv.ParseFloat(numStr, 64)
+	if err != nil || math.IsNaN(pct) {
+		return fmt.Errorf("invalid percentage value in shard_size: %q", strVal)
+	}
+
+	if pct < 0 || pct > 100 {
+		return fmt.Errorf("shard_size percentage must be between 0 and 100, got: %s", strVal)
+	}
+
+	*s = ShardSize{Percent: pct / 100.0, IsPercent: true}
+	return nil
+}
+
+// MarshalJSON serializes the shard size back to JSON.
+func (s ShardSize) MarshalJSON() ([]byte, error) {
+	if s.IsPercent {
+		return json.Marshal(s.percentString())
+	}
+	return json.Marshal(s.Value)
+}
+
 // HashringConfig represents the configuration for a hashring
 // a receive node knows about.
 type HashringConfig struct {
@@ -132,14 +233,14 @@ type HashringConfig struct {
 }
 
 type ShuffleShardingOverrideConfig struct {
-	ShardSize         int           `json:"shard_size"`
+	ShardSize         ShardSize     `json:"shard_size"`
 	Tenants           []string      `json:"tenants,omitempty"`
 	TenantMatcherType tenantMatcher `json:"tenant_matcher_type,omitempty"`
 }
 
 type ShuffleShardingConfig struct {
-	ShardSize int `json:"shard_size"`
-	CacheSize int `json:"cache_size"`
+	ShardSize ShardSize `json:"shard_size"`
+	CacheSize int       `json:"cache_size"`
 	// ZoneAwarenessDisabled disables zone awareness. We still try to spread the load
 	// across the available zones, but we don't try to balance the shards across zones.
ZoneAwarenessDisabled bool `json:"zone_awareness_disabled"` diff --git a/pkg/receive/config_test.go b/pkg/receive/config_test.go index 5ce78e6514b..6e78d582b65 100644 --- a/pkg/receive/config_test.go +++ b/pkg/receive/config_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/pkg/errors" + "github.com/stretchr/testify/require" "github.com/efficientgo/core/testutil" ) @@ -123,3 +124,169 @@ func TestUnmarshalEndpointSlice(t *testing.T) { }) } } + +func TestShardSizeUnmarshalJSON(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + input string + expected ShardSize + expectErr bool + }{ + { + name: "integer value", + input: `6`, + expected: ShardSize{Value: 6}, + }, + { + name: "zero integer", + input: `0`, + expected: ShardSize{Value: 0}, + }, + { + name: "percentage string", + input: `"50%"`, + expected: ShardSize{Percent: 0.5, IsPercent: true}, + }, + { + name: "zero percentage", + input: `"0%"`, + expected: ShardSize{Percent: 0, IsPercent: true}, + }, + { + name: "100 percentage", + input: `"100%"`, + expected: ShardSize{Percent: 1.0, IsPercent: true}, + }, + { + name: "25 percentage", + input: `"25%"`, + expected: ShardSize{Percent: 0.25, IsPercent: true}, + }, + { + name: "invalid string without percent", + input: `"50"`, + expectErr: true, + }, + { + name: "negative percentage", + input: `"-10%"`, + expectErr: true, + }, + { + name: "over 100 percentage", + input: `"150%"`, + expectErr: true, + }, + { + name: "invalid type", + input: `true`, + expectErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var s ShardSize + err := json.Unmarshal([]byte(tc.input), &s) + if tc.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.expected, s) + }) + } +} + +func TestShardSizeMarshalJSON(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + input ShardSize + expected string + }{ + { + name: "integer value", + input: ShardSize{Value: 6}, + expected: `6`, + }, + 
{ + name: "zero value", + input: ShardSize{}, + expected: `0`, + }, + { + name: "percentage", + input: ShardSize{Percent: 0.5, IsPercent: true}, + expected: `"50%"`, + }, + { + name: "100 percentage", + input: ShardSize{Percent: 1.0, IsPercent: true}, + expected: `"100%"`, + }, + } { + t.Run(tc.name, func(t *testing.T) { + data, err := json.Marshal(tc.input) + require.NoError(t, err) + require.Equal(t, tc.expected, string(data)) + }) + } +} + +func TestShardSizeIsZero(t *testing.T) { + t.Parallel() + + require.True(t, ShardSize{}.IsZero()) + require.True(t, ShardSize{Value: 0}.IsZero()) + require.True(t, ShardSize{Percent: 0, IsPercent: true}.IsZero()) + require.False(t, ShardSize{Value: 1}.IsZero()) + require.False(t, ShardSize{Percent: 0.5, IsPercent: true}.IsZero()) +} + +func TestShardSizeResolveCount(t *testing.T) { + t.Parallel() + + // Absolute value: returns Value directly regardless of total. + require.Equal(t, 6, ShardSize{Value: 6}.ResolveCount(100)) + require.Equal(t, 6, ShardSize{Value: 6}.ResolveCount(4)) + + // Percentage: max(1, total * pct). + require.Equal(t, 2, ShardSize{Percent: 0.5, IsPercent: true}.ResolveCount(4)) + require.Equal(t, 1, ShardSize{Percent: 0.25, IsPercent: true}.ResolveCount(4)) + require.Equal(t, 4, ShardSize{Percent: 1.0, IsPercent: true}.ResolveCount(4)) + // Very small percentage still returns at least 1. + require.Equal(t, 1, ShardSize{Percent: 0.01, IsPercent: true}.ResolveCount(4)) +} + +func TestShardSizeRoundTripJSON(t *testing.T) { + t.Parallel() + + // Test that ShardSize round-trips through full config JSON parsing. 
+ cfgJSON := `[{ + "hashring": "test", + "endpoints": [{"address": "node1"}], + "shuffle_sharding_config": { + "shard_size": "50%", + "overrides": [ + {"shard_size": 6, "tenants": ["t1"]}, + {"shard_size": "25%", "tenants": ["t2"]} + ] + } + }]` + + configs, err := ParseConfig([]byte(cfgJSON)) + require.NoError(t, err) + require.Len(t, configs, 1) + + ssc := configs[0].ShuffleShardingConfig + require.True(t, ssc.ShardSize.IsPercent) + require.InDelta(t, 0.5, ssc.ShardSize.Percent, 0.001) + + require.Len(t, ssc.Overrides, 2) + require.False(t, ssc.Overrides[0].ShardSize.IsPercent) + require.Equal(t, 6, ssc.Overrides[0].ShardSize.Value) + require.True(t, ssc.Overrides[1].ShardSize.IsPercent) + require.InDelta(t, 0.25, ssc.Overrides[1].ShardSize.Percent, 0.001) +} diff --git a/pkg/receive/hashring.go b/pkg/receive/hashring.go index dd26e3a50e0..a132f6bc03d 100644 --- a/pkg/receive/hashring.go +++ b/pkg/receive/hashring.go @@ -483,7 +483,7 @@ func newShuffleShardHashring(baseRing Hashring, shuffleShardingConfig ShuffleSha maxNodesInAZ = max(maxNodesInAZ, count) } - if shuffleShardingConfig.ShardSize > maxNodesInAZ { + if !shuffleShardingConfig.ShardSize.IsPercent && shuffleShardingConfig.ShardSize.Value > maxNodesInAZ { level.Warn(l).Log( "msg", "Shard size is larger than the maximum number of nodes in any AZ; some tenants might get all not working nodes if that AZ goes down", "shard_size", shuffleShardingConfig.ShardSize, @@ -492,7 +492,7 @@ func newShuffleShardHashring(baseRing Hashring, shuffleShardingConfig ShuffleSha } for _, override := range shuffleShardingConfig.Overrides { - if override.ShardSize < maxNodesInAZ { + if override.ShardSize.IsPercent || override.ShardSize.Value < maxNodesInAZ { continue } level.Warn(l).Log( @@ -528,7 +528,7 @@ func (s *shuffleShardHashring) dedupedNodes() []Endpoint { } // getShardSize returns the shard size for a specific tenant, taking into account any overrides. 
-func (s *shuffleShardHashring) getShardSize(tenant string) int { +func (s *shuffleShardHashring) getShardSize(tenant string) ShardSize { for _, override := range s.shuffleShardingConfig.Overrides { if override.TenantMatcherType == TenantMatcherTypeExact { for _, t := range override.Tenants { @@ -646,7 +646,11 @@ func (s *shuffleShardHashring) getTenantShard(tenant string) (*ketamaHashring, e sort.Sort(sectionsByAZ[az]) } - ss := s.getShardSize(tenant) + shardSize := s.getShardSize(tenant) + if shardSize.IsPercent { + return nil, fmt.Errorf("percentage shard_size is not supported for ketama algorithm") + } + ss := shardSize.Value var take int if s.shuffleShardingConfig.ZoneAwarenessDisabled { take = ss @@ -867,15 +871,22 @@ func (s *shuffleShardHashring) getTenantShardRendezvous(tenant string) (Hashring return nil, errors.Wrap(err, "failed to extract shard structure") } - // shard_size is the total shard size; divide by numAZs to get per-AZ count. - totalShardSize := s.getShardSize(tenant) + // Determine per-AZ shard count based on shard size type. + shardSize := s.getShardSize(tenant) numAZs := len(azShardMap) - perAZShards := totalShardSize / numAZs // floor + var perAZShards int + if shardSize.IsPercent { + // Percentage: resolve directly against per-AZ common shard count. + perAZShards = shardSize.ResolveCount(len(commonShards)) + } else { + // Absolute: divide total by number of AZs. 
+ perAZShards = shardSize.Value / numAZs // floor + } if perAZShards == 0 { - return nil, fmt.Errorf("shard size %d too small for %d AZs", totalShardSize, numAZs) + return nil, fmt.Errorf("shard size %s too small for %d AZs", shardSize, numAZs) } if perAZShards > len(commonShards) { - return nil, fmt.Errorf("per-AZ shard count %d (from total %d / %d AZs) exceeds available common shards (%d)", perAZShards, totalShardSize, numAZs, len(commonShards)) + return nil, fmt.Errorf("per-AZ shard count %d (from shard_size %s / %d AZs) exceeds available common shards (%d)", perAZShards, shardSize, numAZs, len(commonShards)) } // Select shards using rendezvous hashing @@ -978,7 +989,7 @@ func newHashring(algorithm HashringAlgorithm, endpoints []Endpoint, replicationF if err != nil { return nil, err } - if shuffleShardingConfig.ShardSize > 0 { + if !shuffleShardingConfig.ShardSize.IsZero() { return nil, fmt.Errorf("hashmod algorithm does not support shuffle sharding. Either use Ketama or remove shuffle sharding configuration") } return ringImpl, nil @@ -987,9 +998,12 @@ func newHashring(algorithm HashringAlgorithm, endpoints []Endpoint, replicationF if err != nil { return nil, err } - if shuffleShardingConfig.ShardSize > 0 { - if shuffleShardingConfig.ShardSize > len(endpoints) { - return nil, fmt.Errorf("shard size %d is larger than number of nodes in hashring %s (%d)", shuffleShardingConfig.ShardSize, hashring, len(endpoints)) + if !shuffleShardingConfig.ShardSize.IsZero() { + if shuffleShardingConfig.ShardSize.IsPercent { + return nil, fmt.Errorf("percentage shard_size is not supported for ketama algorithm") + } + if shuffleShardingConfig.ShardSize.Value > len(endpoints) { + return nil, fmt.Errorf("shard size %d is larger than number of nodes in hashring %s (%d)", shuffleShardingConfig.ShardSize.Value, hashring, len(endpoints)) } return newShuffleShardHashring(ringImpl, shuffleShardingConfig, replicationFactor, reg, hashring) } @@ -1000,9 +1014,15 @@ func 
newHashring(algorithm HashringAlgorithm, endpoints []Endpoint, replicationF return nil, err } numShardsGauge.WithLabelValues(hashring).Set(float64(ringImpl.numShards)) - if shuffleShardingConfig.ShardSize > 0 { - if shuffleShardingConfig.ShardSize > len(endpoints) { - return nil, fmt.Errorf("shard size %d is larger than number of nodes in hashring %s (%d)", shuffleShardingConfig.ShardSize, hashring, len(endpoints)) + if !shuffleShardingConfig.ShardSize.IsZero() { + if shuffleShardingConfig.ShardSize.IsPercent { + if shuffleShardingConfig.ShardSize.Percent <= 0 || shuffleShardingConfig.ShardSize.Percent > 1.0 { + return nil, fmt.Errorf("shard_size percentage must be between 0%% and 100%%, got: %s", shuffleShardingConfig.ShardSize) + } + } else { + if shuffleShardingConfig.ShardSize.Value > len(endpoints) { + return nil, fmt.Errorf("shard size %d is larger than number of nodes in hashring %s (%d)", shuffleShardingConfig.ShardSize.Value, hashring, len(endpoints)) + } } return newShuffleShardHashring(ringImpl, shuffleShardingConfig, replicationFactor, reg, hashring) } @@ -1012,7 +1032,7 @@ func newHashring(algorithm HashringAlgorithm, endpoints []Endpoint, replicationF level.Warn(l).Log("msg", "Unrecognizable hashring algorithm. Fall back to hashmod algorithm.", "hashring", hashring, "tenants", tenants) - if shuffleShardingConfig.ShardSize > 0 { + if !shuffleShardingConfig.ShardSize.IsZero() { return nil, fmt.Errorf("hashmod algorithm does not support shuffle sharding. 
Either use Ketama or remove shuffle sharding configuration") } return newSimpleHashring(endpoints) diff --git a/pkg/receive/hashring_test.go b/pkg/receive/hashring_test.go index 145c3fc7a2e..8508a8d9cd7 100644 --- a/pkg/receive/hashring_test.go +++ b/pkg/receive/hashring_test.go @@ -700,11 +700,11 @@ func TestShuffleShardHashring(t *testing.T) { }, tenant: "tenant-1", shuffleShardCfg: ShuffleShardingConfig{ - ShardSize: 2, + ShardSize: ShardSize{Value: 2}, Overrides: []ShuffleShardingOverrideConfig{ { Tenants: []string{"special-tenant"}, - ShardSize: 2, + ShardSize: ShardSize{Value: 2}, }, }, }, @@ -722,11 +722,11 @@ func TestShuffleShardHashring(t *testing.T) { }, tenant: "prefix-tenant", shuffleShardCfg: ShuffleShardingConfig{ - ShardSize: 2, + ShardSize: ShardSize{Value: 2}, Overrides: []ShuffleShardingOverrideConfig{ { Tenants: []string{"prefix*"}, - ShardSize: 3, + ShardSize: ShardSize{Value: 3}, TenantMatcherType: TenantMatcherGlob, }, }, @@ -745,11 +745,11 @@ func TestShuffleShardHashring(t *testing.T) { tenant: "prefix-tenant", err: `shard size 20 is larger than number of nodes in AZ`, shuffleShardCfg: ShuffleShardingConfig{ - ShardSize: 2, + ShardSize: ShardSize{Value: 2}, Overrides: []ShuffleShardingOverrideConfig{ { Tenants: []string{"prefix*"}, - ShardSize: 20, + ShardSize: ShardSize{Value: 20}, TenantMatcherType: TenantMatcherGlob, }, }, @@ -774,12 +774,12 @@ func TestShuffleShardHashring(t *testing.T) { // shuffle. What matters is: (1) exactly 3 nodes are used, and (2) the // selection is stable when scaling (tested in TestShuffleShardHashringStability). 
shuffleShardCfg: ShuffleShardingConfig{ - ShardSize: 1, + ShardSize: ShardSize{Value: 1}, ZoneAwarenessDisabled: true, Overrides: []ShuffleShardingOverrideConfig{ { Tenants: []string{"prefix*"}, - ShardSize: 3, + ShardSize: ShardSize{Value: 3}, TenantMatcherType: TenantMatcherGlob, }, }, @@ -962,7 +962,7 @@ func TestShuffleShardHashringStability(t *testing.T) { } shuffleShardCfg := ShuffleShardingConfig{ - ShardSize: tc.shardSize, + ShardSize: ShardSize{Value: tc.shardSize}, ZoneAwarenessDisabled: true, } diff --git a/pkg/receive/rendezvous_hashring_test.go b/pkg/receive/rendezvous_hashring_test.go index 84f6e37b0df..7ebafd03f5d 100644 --- a/pkg/receive/rendezvous_hashring_test.go +++ b/pkg/receive/rendezvous_hashring_test.go @@ -297,7 +297,7 @@ func TestRendezvousShuffleShardingBasic(t *testing.T) { require.NoError(t, err) cfg := ShuffleShardingConfig{ - ShardSize: 6, + ShardSize: ShardSize{Value: 6}, } shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-rendezvous") require.NoError(t, err) @@ -348,7 +348,7 @@ func TestRendezvousShuffleShardingConsistency(t *testing.T) { require.NoError(t, err) cfg := ShuffleShardingConfig{ - ShardSize: 6, + ShardSize: ShardSize{Value: 6}, } shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-consistency") require.NoError(t, err) @@ -384,7 +384,7 @@ func TestRendezvousShuffleShardingDifferentTenants(t *testing.T) { require.NoError(t, err) cfg := ShuffleShardingConfig{ - ShardSize: 9, + ShardSize: ShardSize{Value: 9}, } shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-diff-tenants") require.NoError(t, err) @@ -423,7 +423,7 @@ func TestRendezvousShuffleShardingPreservesAlignment(t *testing.T) { require.NoError(t, err) cfg := ShuffleShardingConfig{ - ShardSize: 6, + ShardSize: ShardSize{Value: 6}, } shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-preserves") 
require.NoError(t, err) @@ -472,7 +472,7 @@ func TestRendezvousShuffleShardingDataDistribution(t *testing.T) { require.NoError(t, err) cfg := ShuffleShardingConfig{ - ShardSize: 6, + ShardSize: ShardSize{Value: 6}, } shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-distribution") require.NoError(t, err) @@ -582,7 +582,7 @@ func TestRendezvousShuffleShardingValidation(t *testing.T) { require.NoError(t, err) cfg := ShuffleShardingConfig{ - ShardSize: 30, // 30 / 3 AZs = 10 per-AZ, but only 5 shards available + ShardSize: ShardSize{Value: 30}, // 30 / 3 AZs = 10 per-AZ, but only 5 shards available } shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-validation") require.NoError(t, err) @@ -591,3 +591,207 @@ func TestRendezvousShuffleShardingValidation(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "exceeds available common shards") } + +func TestRendezvousShuffleShardingPercentage(t *testing.T) { + t.Parallel() + + // Create 3 AZs with 4 shards each (12 endpoints total, 4 common shards). + endpoints := make([]Endpoint, 0, 12) + azs := []string{"az-a", "az-b", "az-c"} + for _, az := range azs { + for ord := 0; ord < 4; ord++ { + endpoints = append(endpoints, makeK8sEndpoint("pod-"+az, ord, az)) + } + } + + baseRing, err := newRendezvousHashring(endpoints, 3) + require.NoError(t, err) + + t.Run("50% of 4 shards = 2 per AZ", func(t *testing.T) { + cfg := ShuffleShardingConfig{ + ShardSize: ShardSize{Percent: 0.5, IsPercent: true}, + } + shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-pct-50") + require.NoError(t, err) + + shard, err := shardRing.getTenantShardRendezvous("test-tenant") + require.NoError(t, err) + + nodes := shard.Nodes() + // 2 shards per AZ * 3 AZs = 6 endpoints. 
+ require.Len(t, nodes, 6, "expected 6 endpoints (2 shards * 3 AZs)") + + shardsByAZ := make(map[string][]int) + for _, node := range nodes { + shardsByAZ[node.AZ] = append(shardsByAZ[node.AZ], extractShardFromAddress(t, node.Address)) + } + require.Len(t, shardsByAZ, 3) + for az, shards := range shardsByAZ { + require.Len(t, shards, 2, "AZ %s should have 2 shards", az) + } + + // All AZs should have the same shards. + var referenceShards []int + for _, shards := range shardsByAZ { + sort.Ints(shards) + if referenceShards == nil { + referenceShards = shards + } else { + require.Equal(t, referenceShards, shards) + } + } + }) + + t.Run("25% of 4 shards = 1 per AZ", func(t *testing.T) { + cfg := ShuffleShardingConfig{ + ShardSize: ShardSize{Percent: 0.25, IsPercent: true}, + } + shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-pct-25") + require.NoError(t, err) + + shard, err := shardRing.getTenantShardRendezvous("test-tenant") + require.NoError(t, err) + + nodes := shard.Nodes() + // 1 shard per AZ * 3 AZs = 3 endpoints. + require.Len(t, nodes, 3, "expected 3 endpoints (1 shard * 3 AZs)") + + shardsByAZ := make(map[string][]int) + for _, node := range nodes { + shardsByAZ[node.AZ] = append(shardsByAZ[node.AZ], extractShardFromAddress(t, node.Address)) + } + require.Len(t, shardsByAZ, 3) + for az, shards := range shardsByAZ { + require.Len(t, shards, 1, "AZ %s should have 1 shard", az) + } + }) + + t.Run("100% of 4 shards = 4 per AZ", func(t *testing.T) { + cfg := ShuffleShardingConfig{ + ShardSize: ShardSize{Percent: 1.0, IsPercent: true}, + } + shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-pct-100") + require.NoError(t, err) + + shard, err := shardRing.getTenantShardRendezvous("test-tenant") + require.NoError(t, err) + + nodes := shard.Nodes() + // 4 shards per AZ * 3 AZs = 12 endpoints (all of them). 
+ require.Len(t, nodes, 12, "expected 12 endpoints (4 shards * 3 AZs)") + }) + + t.Run("percentage with override", func(t *testing.T) { + cfg := ShuffleShardingConfig{ + ShardSize: ShardSize{Percent: 0.5, IsPercent: true}, + Overrides: []ShuffleShardingOverrideConfig{ + { + Tenants: []string{"special-tenant"}, + ShardSize: ShardSize{Percent: 0.25, IsPercent: true}, + TenantMatcherType: TenantMatcherTypeExact, + }, + }, + } + shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-pct-override") + require.NoError(t, err) + + // Default tenant gets 50%. + shard, err := shardRing.getTenantShardRendezvous("default-tenant") + require.NoError(t, err) + require.Len(t, shard.Nodes(), 6) // 2 shards * 3 AZs + + // Special tenant gets 25%. + shard, err = shardRing.getTenantShardRendezvous("special-tenant") + require.NoError(t, err) + require.Len(t, shard.Nodes(), 3) // 1 shard * 3 AZs + }) +} + +func TestRendezvousShuffleShardingPercentageConsistency(t *testing.T) { + t.Parallel() + + endpoints := make([]Endpoint, 0, 30) + azs := []string{"az-a", "az-b", "az-c"} + for _, az := range azs { + for ord := 0; ord < 10; ord++ { + endpoints = append(endpoints, makeK8sEndpoint("pod-"+az, ord, az)) + } + } + + baseRing, err := newRendezvousHashring(endpoints, 3) + require.NoError(t, err) + + cfg := ShuffleShardingConfig{ + ShardSize: ShardSize{Percent: 0.5, IsPercent: true}, + } + shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-pct-consistency") + require.NoError(t, err) + + tenant := "consistency-tenant" + var firstShards []int + + for trial := 0; trial < 10; trial++ { + shard, err := shardRing.getTenantShardRendezvous(tenant) + require.NoError(t, err) + currentShards := extractShardsFromSubring(t, shard) + if firstShards == nil { + firstShards = currentShards + } else { + require.Equal(t, firstShards, currentShards, "same tenant should always get same shards") + } + } +} + +func 
TestKetamaRejectsPercentageShardSize(t *testing.T) { + t.Parallel() + + // Verify that ketama algorithm rejects percentage shard_size at config validation. + endpoints := []Endpoint{ + {Address: "node-1", AZ: "az-1"}, + {Address: "node-2", AZ: "az-1"}, + {Address: "node-3", AZ: "az-2"}, + {Address: "node-4", AZ: "az-2"}, + } + + cfg := []HashringConfig{ + { + Hashring: "test", + Endpoints: endpoints, + Algorithm: AlgorithmKetama, + ShuffleShardingConfig: ShuffleShardingConfig{ + ShardSize: ShardSize{Percent: 0.5, IsPercent: true}, + }, + }, + } + + _, err := NewMultiHashring(AlgorithmKetama, 2, cfg, prometheus.NewRegistry()) + require.Error(t, err) + require.Contains(t, err.Error(), "percentage shard_size is not supported for ketama algorithm") +} + +func TestRendezvousIntegerShardSizeBackwardCompatibility(t *testing.T) { + t.Parallel() + + // Verify that integer shard_size still works the same way with rendezvous. + endpoints := make([]Endpoint, 0, 15) + azs := []string{"az-a", "az-b", "az-c"} + for _, az := range azs { + for ord := 0; ord < 5; ord++ { + endpoints = append(endpoints, makeK8sEndpoint("pod-"+az, ord, az)) + } + } + + baseRing, err := newRendezvousHashring(endpoints, 3) + require.NoError(t, err) + + // shard_size=6 → 6/3 AZs = 2 shards per AZ → 6 endpoints total. + cfg := ShuffleShardingConfig{ + ShardSize: ShardSize{Value: 6}, + } + shardRing, err := newShuffleShardHashring(baseRing, cfg, 3, prometheus.NewRegistry(), "test-int-compat") + require.NoError(t, err) + + shard, err := shardRing.getTenantShardRendezvous("test-tenant") + require.NoError(t, err) + require.Len(t, shard.Nodes(), 6, "expected 6 endpoints (2 shards per AZ * 3 AZs)") +}