diff --git a/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java b/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java index bda6a6b8d91..692925cabee 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java +++ b/driver-core/src/main/java/com/datastax/driver/core/BatchStatement.java @@ -18,6 +18,7 @@ import com.datastax.driver.core.Frame.Header; import com.datastax.driver.core.Requests.QueryFlag; import com.datastax.driver.core.exceptions.UnsupportedFeatureException; +import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -267,13 +268,10 @@ public BatchStatement setSerialConsistencyLevel(ConsistencyLevel serialConsisten @Override public ByteBuffer getRoutingKey(ProtocolVersion protocolVersion, CodecRegistry codecRegistry) { - for (Statement statement : statements) { - if (statement instanceof StatementWrapper) - statement = ((StatementWrapper) statement).getWrappedStatement(); - ByteBuffer rk = statement.getRoutingKey(protocolVersion, codecRegistry); - if (rk != null) return rk; - } - return null; + Statement routingStatement = getRoutingStatement(protocolVersion, codecRegistry); + return routingStatement == null + ? null + : routingStatement.getRoutingKey(protocolVersion, codecRegistry); } @Override @@ -298,6 +296,22 @@ void ensureAllSet() { if (statement instanceof BoundStatement) ((BoundStatement) statement).ensureAllSet(); } + /** + * Returns the first statement in this batch that provides a routing key for the given protocol + * version and codec registry. + */ + @Beta + public Statement getRoutingStatement( + ProtocolVersion protocolVersion, CodecRegistry codecRegistry) { + for (Statement statement : statements) { + if (statement instanceof StatementWrapper) + statement = ((StatementWrapper) statement).getWrappedStatement(); + ByteBuffer rk = statement.getRoutingKey(protocolVersion, codecRegistry); + if (rk != null) return statement; + } + return null; + } + static class IdAndValues { public final List ids; diff --git a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java index 30eb76ba49b..019494906c1 100644 --- a/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java +++ b/driver-core/src/main/java/com/datastax/driver/core/policies/TokenAwarePolicy.java @@ -21,6 +21,7 @@ */ package com.datastax.driver.core.policies; +import com.datastax.driver.core.BatchStatement; import com.datastax.driver.core.BoundStatement; import com.datastax.driver.core.Cluster; import com.datastax.driver.core.CodecRegistry; @@ -429,16 +430,7 @@ public Iterator newQueryPlan(final String loggedKeyspace, final Statement if (partitionKey == null || keyspace == null) return childPolicy.newQueryPlan(keyspace, statement); - String tableName = null; - ColumnDefinitions defs = null; - if (statement instanceof BoundStatement) { - defs = ((BoundStatement) statement).preparedStatement().getVariables(); - } else if (statement instanceof PreparedStatement) { - defs = ((PreparedStatement) statement).getVariables(); - } - if (defs != null && defs.size() > 0) { - tableName = defs.getTable(0); - } + String tableName = getRoutingTable(statement); final List replicas = clusterMetadata.getReplicasList( @@ -453,6 +445,28 @@ public Iterator newQueryPlan(final String loggedKeyspace, final Statement } } + private String getRoutingTable(Statement statement) { + ColumnDefinitions defs = getRoutingVariables(statement); + return (defs == null || defs.size() == 0) ? null : defs.getTable(0); + } + + private ColumnDefinitions getRoutingVariables(Statement statement) { + Statement target = statement; + if (statement instanceof BatchStatement) { + target = ((BatchStatement) statement).getRoutingStatement(protocolVersion, codecRegistry); + if (target == null) { + return null; + } + } + + if (target instanceof BoundStatement) { + return ((BoundStatement) target).preparedStatement().getVariables(); + } else if (target instanceof PreparedStatement) { + return ((PreparedStatement) target).getVariables(); + } + return null; + } + private QueryOptions.RequestRoutingMethod getRequestRouting(Statement statement) { if (!statement.isLWT() || defaultLwtRequestRoutingMethod == null) { return QueryOptions.RequestRoutingMethod.REGULAR; diff --git a/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java b/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java index f6583cc2215..f3ecdd362c0 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java +++ b/driver-core/src/test/java/com/datastax/driver/core/TabletsIT.java @@ -298,6 +298,29 @@ public void should_receive_each_tablet_exactly_once() { } } + @Test(groups = "short") + public void batch_statement_should_deliver_tablet_info_and_route_properly() { + prepareCluster(); + Session session = newSession(); + try { + session + .getCluster() + .getMetadata() + .getTabletMap() + .removeTableMappings(KEYSPACE_NAME.toLowerCase()); + + PreparedStatement preparedStatement = session.prepare(STMT_INSERT); + Assert.assertTrue( + executeOnAllHostsAndReturnIfResultHasTabletsInfo(session, preparedStatement.bind(2, 2))); + Assert.assertTrue(waitSessionLearnedTabletInfo(session)); + + BatchStatement routedBatch = new BatchStatement().add(preparedStatement.bind(2, 2)); + Assert.assertTrue(checkIfRoutedProperly(session, routedBatch)); + } finally { + session.close(); + } + } + private static boolean waitSessionLearnedTabletInfo(Session session) { if (isSessionLearnedTabletInfo(session)) { return true; diff --git a/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java b/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java index eeb57e91b2b..3860ea92ef7 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java +++ b/driver-core/src/test/java/com/datastax/driver/core/policies/TokenAwarePolicyTest.java @@ -29,12 +29,16 @@ import static com.datastax.driver.core.policies.TokenAwarePolicy.ReplicaOrdering.TOPOLOGICAL; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.datastax.driver.core.BatchStatement; import com.datastax.driver.core.BoundStatement; import com.datastax.driver.core.CCMBridge; import com.datastax.driver.core.Cluster; import com.datastax.driver.core.CodecRegistry; +import com.datastax.driver.core.ColumnDefinitions; import com.datastax.driver.core.Configuration; import com.datastax.driver.core.Host; import com.datastax.driver.core.HostDistance; @@ -155,6 +159,65 @@ public void should_create_random_order() { assertThat(queryPlan).containsOnlyOnce(host1, host2, host3, host4).endsWith(host4, host3); } + @Test(groups = "unit") + public void should_use_table_name_from_bound_statement_for_tablet_routing() { + // given + BoundStatement bound = newBoundStatement("tablets_table", routingKey); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2)); + when(childPolicy.newQueryPlan(KEYSPACE, bound)) + .thenReturn(Lists.newArrayList(host4, host3, host2, host1).iterator()); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, bound); + + // then + assertThat(queryPlan).containsExactly(host1, host2, host4, host3); + verify(metadata).getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey); + } + + @Test(groups = "unit") + public void should_use_table_name_from_routed_statement_in_batch_for_tablet_routing() { + // given + BoundStatement skippedBound = newBoundStatement("ignored_table", null); + BoundStatement routedBound = newBoundStatement("tablets_table", routingKey); + + BatchStatement batch = new BatchStatement().add(skippedBound).add(routedBound); + when(metadata.getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey)) + .thenReturn(Lists.newArrayList(host1, host2)); + when(childPolicy.newQueryPlan(KEYSPACE, batch)) + .thenReturn(Lists.newArrayList(host4, host3, host2, host1).iterator()); + + TokenAwarePolicy policy = new TokenAwarePolicy(childPolicy, TOPOLOGICAL); + policy.init(cluster, null); + + // when + Iterator queryPlan = policy.newQueryPlan(KEYSPACE, batch); + + // then + assertThat(queryPlan).containsExactly(host1, host2, host4, host3); + verify(metadata).getReplicasList(Metadata.quote(KEYSPACE), "tablets_table", null, routingKey); + verify(metadata, never()) + .getReplicasList(Metadata.quote(KEYSPACE), "ignored_table", null, routingKey); + } + + private BoundStatement newBoundStatement(String table, ByteBuffer routingKey) { + BoundStatement bound = mock(BoundStatement.class); + PreparedStatement prepared = mock(PreparedStatement.class); + ColumnDefinitions variables = mock(ColumnDefinitions.class); + when(bound.getKeyspace()).thenReturn(KEYSPACE); + when(bound.getRoutingKey(any(ProtocolVersion.class), any(CodecRegistry.class))) + .thenReturn(routingKey); + when(bound.preparedStatement()).thenReturn(prepared); + when(prepared.getVariables()).thenReturn(variables); + when(variables.size()).thenReturn(1); + when(variables.getTable(0)).thenReturn(table); + return bound; + } + @Test(groups = "unit", dataProvider = "shuffleProvider") public void should_prioritize_local_replicas_for_lwt(TokenAwarePolicy.ReplicaOrdering ordering) { // given