From e989b359252a2f9ff1cafff131caa5a88b406456 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Fri, 24 Apr 2026 09:52:20 -0400 Subject: [PATCH] fix: route tablet-aware batches via routed statement Batch statements can expose a routing key from one child statement while TokenAwarePolicy derives tablet replica lookup from a different context. For tablet keyspaces, replica lookup also needs the table name that belongs to the same child statement that supplied the routing key. Teach TokenAwarePolicy to resolve routing metadata from the batch child selected for routing, while preserving the existing single-statement BoundStatement and PreparedStatement behavior. Tests cover bound statement tablet lookup, batch tablet lookup through TokenAwarePolicy, and an integration path where learned tablet metadata is used to route a BatchStatement. Fixes #879 --- .../datastax/driver/core/BatchStatement.java | 28 ++++++--- .../core/policies/TokenAwarePolicy.java | 34 +++++++--- .../com/datastax/driver/core/TabletsIT.java | 23 +++++++ .../core/policies/TokenAwarePolicyTest.java | 63 +++++++++++++++++++ 4 files changed, 131 insertions(+), 17 deletions(-) diff --git a/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java b/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java index bda6a6b8d91..692925cabee 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java +++ b/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java @@ -18,6 +18,7 @@ import com.datastax.driver.core.Frame.Header; import com.datastax.driver.core.Requests.QueryFlag; import com.datastax.driver.core.exceptions.UnsupportedFeatureException; +import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -267,13 +268,10 @@ public BatchStatement setSerialConsistencyLevel(ConsistencyLevel serialConsisten @Override public ByteBuffer getRoutingKey(ProtocolVersion protocolVersion, CodecRegistry codecRegistry) { - for (Statement statement : statements) { - if (statement instanceof StatementWrapper) - statement = ((StatementWrapper) statement).getWrappedStatement(); - ByteBuffer rk = statement.getRoutingKey(protocolVersion, codecRegistry); - if (rk != null) return rk; - } - return null; + Statement routingStatement = getRoutingStatement(protocolVersion, codecRegistry); + return routingStatement == null + ? null + : routingStatement.getRoutingKey(protocolVersion, codecRegistry); } @Override @@ -298,6 +296,22 @@ void ensureAllSet() { if (statement instanceof BoundStatement) ((BoundStatement) statement).ensureAllSet(); } + /** + * Returns the first statement in this batch that provides a routing key for the given protocol + * version and codec registry. + */ + @Beta + public Statement getRoutingStatement( + ProtocolVersion protocolVersion, CodecRegistry codecRegistry) { + for (Statement statement : statements) { + if (statement instanceof StatementWrapper) + statement = ((StatementWrapper) statement).getWrappedStatement(); + ByteBuffer rk = statement.getRoutingKey(protocolVersion, codecRegistry); + if (rk != null) return statement; + } + return null; + } + static class IdAndValues { public final List ids; diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java index 30eb76ba49b..019494906c1 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java @@ -21,6 +21,7 @@ */ package com.datastax.driver.core.policies; +import com.datastax.driver.core.BatchStatement; import com.datastax.driver.core.BoundStatement; import com.datastax.driver.core.Cluster; import com.datastax.driver.core.CodecRegistry; @@ -429,16 +430,7 @@ public Iterator newQueryPlan(final String loggedKeyspace, final Statement if (partitionKey == null || keyspace == null) return childPolicy.newQueryPlan(keyspace, statement); - String tableName = null; - ColumnDefinitions defs = null; - if (statement instanceof BoundStatement) { - defs = ((BoundStatement) statement).preparedStatement().getVariables(); - } else if (statement instanceof PreparedStatement) { - defs = ((PreparedStatement) statement).getVariables(); - } - if (defs != null && defs.size() > 0) { - tableName = defs.getTable(0); - } + String tableName = getRoutingTable(statement); final List replicas = clusterMetadata.getReplicasList( @@ -453,6 +445,28 @@ public Iterator newQueryPlan(final String loggedKeyspace, final Statement } } + private String getRoutingTable(Statement statement) { + ColumnDefinitions defs = getRoutingVariables(statement); + return (defs == null || defs.size() == 0) ? null : defs.getTable(0); + } + + private ColumnDefinitions getRoutingVariables(Statement statement) { + Statement target = statement; + if (statement instanceof BatchStatement) { + target = ((BatchStatement) statement).getRoutingStatement(protocolVersion, codecRegistry); + if (target == null) { + return null; + } + } + + if (target instanceof BoundStatement) { + return ((BoundStatement) target).preparedStatement().getVariables(); + } else if (target instanceof PreparedStatement) { + return ((PreparedStatement) target).getVariables(); + } + return null; + } + private QueryOptions.RequestRoutingMethod getRequestRouting(Statement statement) { if (!statement.isLWT() || defaultLwtRequestRoutingMethod == null) { return QueryOptions.RequestRoutingMethod.REGULAR; diff --git a/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java b/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java index f6583cc2215..f3ecdd362c0 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java +++ b/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java @@ -298,6 +298,29 @@ public void should_receive_each_tablet_exactly_once() { } } + @Test(groups = "short") + public void batch_statement_should_deliver_tablet_info_and_route_properly() { + prepareCluster(); + Session session = newSession(); + try { + session + .getCluster() + .getMetadata() + .getTabletMap() + .removeTableMappings(KEYSPACE_NAME.toLowerCase()); + + PreparedStatement preparedStatement = session.prepare(STMT_INSERT); + Assert.assertTrue( + executeOnAllHostsAndReturnIfResultHasTabletsInfo(session, preparedStatement.bind(2, 2))); + Assert.assertTrue(waitSessionLearnedTabletInfo(session)); + + BatchStatement routedBatch = new BatchStatement().add(preparedStatement.bind(2, 2)); + Assert.assertTrue(checkIfRoutedProperly(session, routedBatch)); + } finally { + session.close(); + } + } + private static boolean waitSessionLearnedTabletInfo(Session session) { if (isSessionLearnedTabletInfo(session)) { return true; diff --git a/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java b/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java index eeb57e91b2b..3860ea92ef7 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java +++ b/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java @@ -29,12 +29,16 @@ import static com.datastax.driver.core.policies.TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.datastax.driver.core.BatchStatement; import com.datastax.driver.core.BoundStatement; import com.datastax.driver.core.CCMBridge; import com.datastax.driver.core.Cluster; import com.datastax.driver.core.CodecRegistry; +import com.datastax.driver.core.ColumnDefinitions; import com.datastax.driver.core.Configuration; import com.datastax.driver.core.Host; import com.datastax.driver.core.HostDistance; @@ -155,6 +159,65 @@ public void should_create_random_order() { assertThat(queryPlan).containsOnlyOnce(host1, host2, host3, host4).endsWith(host4, host3); } + @Test(groups = "unit") + public void should_use_table_name_from_bound_statement_for_tablet_routing() { + // given + BoundStatement bound = newBoundStatement("tablets_table", routingKey); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2)); + when(childPolicy.newQueryPlan(KEYSPACE, bound)) + .thenReturn(Lists.newArrayList(host4, host3, host2, host1).iterator()); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, bound); + + // then + assertThat(queryPlan).containsExactly(host1, host2, host4, host3); + verify(metadata).getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey); + } + + @Test(groups = "unit") + public void should_use_table_name_from_routed_statement_in_batch_for_tablet_routing() { + // given + BoundStatement skippedBound = newBoundStatement("ignored_table", null); + BoundStatement routedBound = newBoundStatement("tablets_table", routingKey); + + BatchStatement batch = new BatchStatement().add(skippedBound).add(routedBound); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2)); + when(childPolicy.newQueryPlan(KEYSPACE, batch)) + .thenReturn(Lists.newArrayList(host4, host3, host2, host1).iterator()); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, batch); + + // then + assertThat(queryPlan).containsExactly(host1, host2, host4, host3); + verify(metadata).getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey); + verify(metadata, never()) + .getReplicasList(Metadata.quote(KEYSPACE), "ignored_table", null, routingKey); + } + + private BoundStatement newBoundStatement(String table, ByteBuffer routingKey) { + BoundStatement bound = mock(BoundStatement.class); + PreparedStatement prepared = mock(PreparedStatement.class); + ColumnDefinitions variables = mock(ColumnDefinitions.class); + when(bound.getKeyspace()).thenReturn(KEYSPACE); + when(bound.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(bound.preparedStatement()).thenReturn(prepared); + when(prepared.getVariables()).thenReturn(variables); + when(variables.size()).thenReturn(1); + when(variables.getTable(0)).thenReturn(table); + return bound; + } + @Test(groups = "unit", dataProvider = "shuffleProvider") public void should_prioritize_local_replicas_for_lwt(TokenAwarePolicy.ReplicaOrdering ordering) { // given