Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import com.datastax.oss.driver.api.core.ConsistencyLevel;
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.RequestRoutingType;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.context.DriverContext;
Expand Down Expand Up @@ -63,6 +64,9 @@
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -71,8 +75,10 @@
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import net.jcip.annotations.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -113,6 +119,11 @@
@ThreadSafe
public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {

/**
 * Selects which plan-building strategy {@code newQueryPlan} dispatches to for a given request.
 */
public enum RequestRoutingMethod {
/** Build the plan with {@code newQueryPlanRegular}; also used when no other method applies. */
REGULAR,
/** Build the plan with {@code newQueryPlanPreserveReplicas}. */
PRESERVE_REPLICA_ORDER
}

private static final Logger LOG = LoggerFactory.getLogger(BasicLoadBalancingPolicy.class);

protected static final IntUnaryOperator INCREMENT = i -> (i == Integer.MAX_VALUE) ? 0 : i + 1;
Expand All @@ -127,6 +138,7 @@ public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
private final int maxNodesPerRemoteDc;
private final boolean allowDcFailoverForLocalCl;
private final ConsistencyLevel defaultConsistencyLevel;
private final RequestRoutingMethod lwtRequestRoutingMethod;

// private because they should be set in init() and never be modified after
private volatile DistanceReporter distanceReporter;
Expand Down Expand Up @@ -154,6 +166,34 @@ public BasicLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String
new LinkedHashSet<>(
profile.getStringList(
DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS));
this.lwtRequestRoutingMethod = parseLwtRequestRoutingMethod();
}

/**
 * Reads the routing method configured for LWT requests from the execution profile.
 *
 * <p>The configured string is matched case-insensitively against the {@link
 * RequestRoutingMethod} constant names. Upper-casing is done with {@link java.util.Locale#ROOT}
 * so the result does not depend on the JVM default locale: with a Turkish default locale,
 * {@code "i"} upper-cases to a dotted {@code "İ"}, which would make {@code
 * "preserve_replica_order"} fail to parse and trigger a spurious warning.
 *
 * @return the parsed method, or {@link RequestRoutingMethod#PRESERVE_REPLICA_ORDER} (after
 *     logging a warning) when the configured value does not name a known method.
 */
@NonNull
private RequestRoutingMethod parseLwtRequestRoutingMethod() {
  String methodString =
      profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD);
  try {
    // Locale.ROOT keeps the enum-name lookup independent of the platform locale.
    return RequestRoutingMethod.valueOf(methodString.toUpperCase(java.util.Locale.ROOT));
  } catch (IllegalArgumentException e) {
    LOG.warn(
        "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER",
        logPrefix,
        methodString);
    return RequestRoutingMethod.PRESERVE_REPLICA_ORDER;
  }
}

/**
 * Determines which routing method applies to the given request.
 *
 * @param request the request being routed; may be {@code null}.
 * @return the configured LWT routing method when the request's routing type is {@link
 *     RequestRoutingType#LWT}; {@link RequestRoutingMethod#REGULAR} for every other request,
 *     including a {@code null} one.
 */
@NonNull
public RequestRoutingMethod getRequestRoutingMethod(@Nullable Request request) {
  boolean isLwt = request != null && request.getRequestRoutingType() == RequestRoutingType.LWT;
  return isLwt ? lwtRequestRoutingMethod : RequestRoutingMethod.REGULAR;
}

/**
Expand Down Expand Up @@ -260,6 +300,17 @@ protected NodeDistanceEvaluator createNodeDistanceEvaluator(
/**
 * Builds the query plan for a request by dispatching on its routing method: {@code
 * PRESERVE_REPLICA_ORDER} goes to {@link #newQueryPlanPreserveReplicas}, everything else to
 * {@link #newQueryPlanRegular}.
 */
@NonNull
@Override
public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session session) {
  if (getRequestRoutingMethod(request) == RequestRoutingMethod.PRESERVE_REPLICA_ORDER) {
    return newQueryPlanPreserveReplicas(request, session);
  }
  // REGULAR, and the fallback for any other value
  return newQueryPlanRegular(request, session);
}

@NonNull
protected Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Session session) {
// Take a snapshot since the set is concurrent:
Object[] currentNodes = liveNodes.dc(localDc).toArray();

Expand Down Expand Up @@ -294,6 +345,116 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
return maybeAddDcFailover(request, plan);
}

/**
 * Builds a query plan that keeps the replicas in the order returned by {@link #getReplicas}.
 *
 * <p>With a local datacenter the plan is: local replicas, remote replicas, local non-replicas
 * (rotated), then remote non-replicas (rotated per DC). Without a local datacenter the plan is
 * all replicas followed by the rotated non-replicas.
 */
@NonNull
protected Queue<Node> newQueryPlanPreserveReplicas(
    @Nullable Request request, @Nullable Session session) {
  List<Node> replicas = getReplicas(request, session);
  String localDc = getLocalDatacenter();
  List<Node> queryPlan = new ArrayList<>();

  if (localDc != null) {
    // Local DC known: local nodes take priority over remote ones
    Map<String, List<Node>> nodesByDc = getAllNodesByDc();
    addReplicasByDc(queryPlan, replicas, localDc);
    addNonReplicasByDc(queryPlan, nodesByDc, replicas, localDc, request);
    return new SimpleQueryPlan(queryPlan.toArray());
  }

  // No local DC: snapshot every live node, put replicas first, then the rotated rest
  List<Node> allNodes = dcNodeList(null);
  queryPlan.addAll(replicas);
  addRotatedNonReplicas(queryPlan, allNodes, replicas, request);
  return new SimpleQueryPlan(queryPlan.toArray());
}

/**
 * Collects the nodes reported by {@code getLiveNodes()}, grouped by datacenter.
 *
 * <p>The returned map iterates in insertion order: first each entry of {@code
 * preferredRemoteDcs} that is present in {@code getLiveNodes().dcs()}, in the iteration order of
 * that collection; then every remaining DC, sorted by the natural ordering of its name.
 *
 * @return a new map from DC name to a snapshot list of that DC's nodes.
 */
private Map<String, List<Node>> getAllNodesByDc() {
// LinkedHashMap: callers rely on iteration following the insertion order built below
Map<String, List<Node>> nodesByDc = new LinkedHashMap<>();
Set<String> allDcs = getLiveNodes().dcs();
// Add preferred remote DCs first (in configured order)
for (String dc : preferredRemoteDcs) {
if (allDcs.contains(dc)) {
nodesByDc.put(dc, dcNodeList(dc));
}
}
Comment on lines +381 to +386
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new preserve-replica-order logic adds explicit ordering for preferredRemoteDcs, but the updated LWT routing tests don’t appear to assert that preferred remote DCs are actually prioritized ahead of other remote DCs in the resulting plan. Add a unit test that configures multiple remote DCs plus LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS, then asserts the remote non-replica portion of the plan respects that configured order.

Copilot uses AI. Check for mistakes.
// Add remaining DCs (sorted for deterministic ordering); DCs already inserted above are skipped
allDcs.stream()
.sorted()
.filter(dc -> !nodesByDc.containsKey(dc))
.forEach(dc -> nodesByDc.put(dc, dcNodeList(dc)));
return nodesByDc;
}
Comment on lines +378 to +393
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ordering of allDcs iteration is undefined for a generic Set, so the “remaining DCs” portion of the query-plan order can become non-deterministic across JVM runs (and harder to reason about/debug). Consider using a deterministic order for non-preferred DCs (e.g., sort the remaining DC names, or use an order guaranteed by the underlying dcs() implementation if such a guarantee exists and is documented).

Copilot uses AI. Check for mistakes.

/**
 * Returns a new mutable list holding a snapshot of {@code getLiveNodes().dc(dc)}.
 *
 * @param dc the datacenter name to look up; callers in this class may pass {@code null}.
 */
private List<Node> dcNodeList(String dc) {
  // Take the array snapshot first, then copy it into a list sized to fit
  Object[] snapshot = getLiveNodes().dc(dc).toArray();
  List<Node> dcNodes = new ArrayList<>(snapshot.length);
  for (Object element : snapshot) {
    dcNodes.add((Node) element);
  }
  return dcNodes;
}

/**
 * Appends the replicas to the plan, local-DC replicas first and all others after them; the
 * relative order inside each group is the order of {@code replicas}.
 */
private void addReplicasByDc(List<Node> queryPlan, List<Node> replicas, String localDc) {
  // Single pass: local replicas go straight into the plan, the rest are held back
  List<Node> remoteReplicas = new ArrayList<>();
  for (Node replica : replicas) {
    if (Objects.equals(replica.getDatacenter(), localDc)) {
      queryPlan.add(replica);
    } else {
      remoteReplicas.add(replica);
    }
  }
  queryPlan.addAll(remoteReplicas);
}

/**
 * Appends the non-replica nodes to the plan: those of the local DC first, then those of every
 * other DC in the iteration order of {@code nodesByDc}. Each DC's group is rotated
 * independently.
 */
private void addNonReplicasByDc(
    List<Node> queryPlan,
    Map<String, List<Node>> nodesByDc,
    List<Node> replicas,
    String localDc,
    Request request) {
  // Local DC group, if the map has one
  List<Node> localNodes = nodesByDc.get(localDc);
  if (localNodes != null) {
    addRotatedNonReplicas(queryPlan, localNodes, replicas, request);
  }
  // Every other DC, in map order
  nodesByDc.forEach(
      (dc, dcNodes) -> {
        if (!Objects.equals(dc, localDc)) {
          addRotatedNonReplicas(queryPlan, dcNodes, replicas, request);
        }
      });
}

/**
 * Appends to the plan every node of {@code nodes} that is not in {@code replicas}, after
 * rotating that group with {@link #rotateNonReplicas}. Does nothing when no such node exists.
 */
private void addRotatedNonReplicas(
    List<Node> queryPlan, List<Node> nodes, List<Node> replicas, Request request) {
  List<Node> nonReplicas = new ArrayList<>();
  for (Node candidate : nodes) {
    if (!replicas.contains(candidate)) {
      nonReplicas.add(candidate);
    }
  }
  if (nonReplicas.isEmpty()) {
    return;
  }
  rotateNonReplicas(nonReplicas, request);
  queryPlan.addAll(nonReplicas);
}
Comment on lines +434 to +442
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replicas.contains(n) inside a stream makes non-replica selection O(N*M) and can become expensive with larger node/replica sets. Convert replicas to a Set<Node> once (e.g., in newQueryPlanPreserveReplicas) and use that for membership checks when building nonReplicas.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The replicas list is bounded by the replication factor, which is typically 3–5. With such small lists, List.contains() is effectively O(1) in practice and likely faster than the overhead of creating a HashSet (hashing, boxing, allocation). Not worth optimizing.


/**
 * Rotates the list in place. The offset is derived from the hash of the request's routing key
 * when one is available (so the same key gives the same offset for a given list size), and
 * from {@link #randomNextInt} otherwise. Lists with fewer than two elements are left untouched.
 */
private void rotateNonReplicas(List<Node> nodes, @Nullable Request request) {
  int size = nodes.size();
  if (size <= 1) {
    return;
  }

  int rotationAmount;
  if (request == null || request.getRoutingKey() == null) {
    rotationAmount = randomNextInt(size);
  } else {
    // Mask off the sign bit so the modulo result is never negative
    rotationAmount = (request.getRoutingKey().hashCode() & 0x7fffffff) % size;
  }

  if (rotationAmount > 0) {
    Collections.rotate(nodes, -rotationAmount);
  }
}

@NonNull
protected List<Node> getReplicas(@Nullable Request request, @Nullable Session session) {
if (request == null || session == null) {
Expand Down Expand Up @@ -441,6 +602,11 @@ protected Object[] computeNodes() {
return new CompositeQueryPlan(queryPlans);
}

/**
 * Returns a random int in {@code [0, bound)} drawn from the current thread's {@link
 * ThreadLocalRandom}. Exposed as a protected method so that it can be accessed by tests.
 */
protected int randomNextInt(int bound) {
  ThreadLocalRandom random = ThreadLocalRandom.current();
  return random.nextInt(bound);
}

/** Exposed as a protected method so that it can be accessed by tests */
protected void shuffleHead(Object[] currentNodes, int headLength) {
ArrayUtils.shuffleHead(currentNodes, headLength);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.MINUTES;

import com.datastax.oss.driver.api.core.RequestRoutingType;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.context.DriverContext;
Expand Down Expand Up @@ -48,7 +47,6 @@
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLongArray;
import net.jcip.annotations.ThreadSafe;
import org.slf4j.Logger;
Expand Down Expand Up @@ -96,11 +94,6 @@
@ThreadSafe
public class DefaultLoadBalancingPolicy extends BasicLoadBalancingPolicy implements RequestTracker {

/**
 * Selects which plan-building strategy {@code newQueryPlan} dispatches to for a given request.
 */
public enum RequestRoutingMethod {
/** Build the plan with {@code newQueryPlanRegular}; also used when no other method applies. */
REGULAR,
/** Build the plan with {@code newQueryPlanPreserveReplicas}. */
PRESERVE_REPLICA_ORDER
}

private static final Logger LOG = LoggerFactory.getLogger(DefaultLoadBalancingPolicy.class);

private static final long NEWLY_UP_INTERVAL_NANOS = MINUTES.toNanos(1);
Expand All @@ -110,31 +103,14 @@ public enum RequestRoutingMethod {
protected final ConcurrentMap<Node, NodeResponseRateSample> responseTimes;
protected final Map<Node, Long> upTimes = new ConcurrentHashMap<>();
private final boolean avoidSlowReplicas;
private final RequestRoutingMethod lwtRequestRoutingMethod;

/**
 * Creates the policy for the given profile, reading the slow-replica avoidance flag (default
 * {@code true}) and the LWT request routing method from the execution profile.
 */
public DefaultLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String profileName) {
  super(context, profileName);
  avoidSlowReplicas =
      profile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true);
  lwtRequestRoutingMethod = getDefaultLWTRequestRoutingMethod();
  // Weak keys: an entry does not keep its Node key reachable
  responseTimes = new MapMaker().weakKeys().makeMap();
}

/**
 * Reads the routing method configured for LWT requests from the execution profile.
 *
 * <p>The configured string is matched case-insensitively against the {@link
 * RequestRoutingMethod} constant names. Upper-casing is done with {@link java.util.Locale#ROOT}
 * so the result does not depend on the JVM default locale: with a Turkish default locale,
 * {@code "i"} upper-cases to a dotted {@code "İ"}, which would make {@code
 * "preserve_replica_order"} fail to parse and trigger a spurious warning.
 *
 * @return the parsed method, or {@link RequestRoutingMethod#PRESERVE_REPLICA_ORDER} (after
 *     logging a warning) when the configured value does not name a known method.
 */
@NonNull
private RequestRoutingMethod getDefaultLWTRequestRoutingMethod() {
  String methodString =
      profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD);
  try {
    // Locale.ROOT keeps the enum-name lookup independent of the platform locale.
    return RequestRoutingMethod.valueOf(methodString.toUpperCase(java.util.Locale.ROOT));
  } catch (IllegalArgumentException e) {
    LOG.warn(
        "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER",
        logPrefix,
        methodString);
    return RequestRoutingMethod.PRESERVE_REPLICA_ORDER;
  }
}

@NonNull
@Override
public Optional<RequestTracker> getRequestTracker() {
Expand All @@ -151,52 +127,13 @@ protected Optional<String> discoverLocalDc(@NonNull Map<UUID, Node> nodes) {
return new MandatoryLocalDcHelper(context, profile, logPrefix).discoverLocalDc(nodes);
}

/**
 * Determines which routing method applies to the given request.
 *
 * @param request the request being routed; may be {@code null}.
 * @return the configured LWT routing method when the request's routing type is {@link
 *     RequestRoutingType#LWT}; {@link RequestRoutingMethod#REGULAR} for every other request,
 *     including a {@code null} one.
 */
@NonNull
public RequestRoutingMethod getDefaultLWTRequestRoutingMethod(@Nullable Request request) {
  boolean isLwt = request != null && request.getRequestRoutingType() == RequestRoutingType.LWT;
  return isLwt ? lwtRequestRoutingMethod : RequestRoutingMethod.REGULAR;
}

/**
 * Builds the query plan for a request by dispatching on its routing method: {@code
 * PRESERVE_REPLICA_ORDER} goes to {@link #newQueryPlanPreserveReplicas}, everything else to
 * {@link #newQueryPlanRegular}.
 */
@NonNull
@Override
public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session session) {
  if (getDefaultLWTRequestRoutingMethod(request) == RequestRoutingMethod.PRESERVE_REPLICA_ORDER) {
    return newQueryPlanPreserveReplicas(request, session);
  }
  // REGULAR, and the fallback for any other value
  return newQueryPlanRegular(request, session);
}

/**
 * Builds a query plan made only of the replicas returned by {@link #getReplicas}, keeping
 * their order but moving replicas outside the local datacenter after the local ones. When
 * there is no local datacenter, or no replicas, the replicas are used as-is.
 */
@NonNull
public Queue<Node> newQueryPlanPreserveReplicas(
    @Nullable Request request, @Nullable Session session) {
  List<Node> replicas = getReplicas(request, session);
  String localDc = getLocalDatacenter();
  boolean needsReordering = localDc != null && !replicas.isEmpty();
  Object[] planNodes =
      needsReordering ? moveNonLocalReplicasToTheEnd(replicas, localDc) : replicas.toArray();
  return new SimpleQueryPlan(planNodes);
}

/**
* Builds a query plan that prioritizes local replicas, shuffles them for balance, and then
* round-robins the remaining local nodes.
*/
@NonNull
public Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Session session) {
@Override
protected Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Session session) {
List<Node> replicas = getReplicas(request, session);
Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray();
int replicaCount = 0; // in currentNodes
Expand Down Expand Up @@ -228,26 +165,6 @@ public Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Sess
return maybeAddDcFailover(request, plan);
}

/**
 * Returns a new array with the replicas whose datacenter equals {@code localDc} first and all
 * other replicas after them; the relative order inside each group follows {@code replicas}.
 */
private static Object[] moveNonLocalReplicasToTheEnd(List<Node> replicas, String localDc) {
  // First pass: count local replicas to find where the remote section starts
  int localCount = 0;
  for (Node replica : replicas) {
    if (Objects.equals(replica.getDatacenter(), localDc)) {
      localCount++;
    }
  }
  // Second pass: fill both sections at once with independent cursors
  Object[] orderedReplicas = new Object[replicas.size()];
  int localCursor = 0;
  int remoteCursor = localCount;
  for (Node replica : replicas) {
    if (Objects.equals(replica.getDatacenter(), localDc)) {
      orderedReplicas[localCursor++] = replica;
    } else {
      orderedReplicas[remoteCursor++] = replica;
    }
  }
  return orderedReplicas;
}

private int[] moveReplicasToFront(Object[] currentNodes, List<Node> allReplicas) {
int replicaCount = 0, localRackReplicaCount = 0;
for (int i = 0; i < currentNodes.length; i++) {
Expand Down Expand Up @@ -329,7 +246,7 @@ private void avoidSlowReplicas(
// - the replica in first or second position is the most recent replica marked as UP and
// - dice roll 1d4 != 1
else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1])
&& diceRoll1d4() != 1) {
&& randomNextInt(4) != 1) {

// Send it to the back of the replicas
ArrayUtils.bubbleDown(
Expand Down Expand Up @@ -370,11 +287,6 @@ protected long nanoTime() {
return System.nanoTime();
}

/**
 * Returns a random int in {@code [0, 4)} drawn from the current thread's {@link
 * ThreadLocalRandom}. Exposed as a protected method so that it can be accessed by tests.
 */
protected int diceRoll1d4() {
  ThreadLocalRandom random = ThreadLocalRandom.current();
  return random.nextInt(4);
}

/**
 * Reports whether a node is both busy for this session and has an insufficient response rate
 * at time {@code now}. The response-rate check is only evaluated when the node is busy.
 */
protected boolean isUnhealthy(@NonNull Node node, @NonNull Session session, long now) {
  if (!isBusy(node, session)) {
    return false;
  }
  return isResponseRateInsufficient(node, now);
}
Expand Down
Loading
Loading