From cb519e6bac0689fe65849abbd9fdfd270349bfc4 Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Fri, 23 Jan 2026 14:34:19 +0100 Subject: [PATCH 1/4] feat: introduce RequestRoutingType and RequestRoutingMethod enums - Added RequestRoutingType and RequestRoutingMethod enums to define request routing strategies. - Updated DefaultLoadBalancingPolicy to consider request routing type for replica selection, especially for LWT requests. - Updated various graph statement classes (BytecodeGraphStatement, DefaultBatchGraphStatement, DefaultFluentGraphStatement, DefaultScriptGraphStatement) to implement getRequestRoutingType method. - Modified BatchStatementBuilder to set request routing type based on LWT status. - Enhanced DefaultBatchStatement, DefaultBoundStatement, DefaultPrepareRequest, and DefaultSimpleStatement to include routing type in constructors and methods. - Added logic to avoid slow replicas based on health checks and in-flight requests. --- .../core/graph/BytecodeGraphStatement.java | 8 + .../graph/DefaultBatchGraphStatement.java | 7 + .../graph/DefaultFluentGraphStatement.java | 7 + .../graph/DefaultScriptGraphStatement.java | 7 + .../driver/api/core/RequestRoutingMethod.java | 7 + .../driver/api/core/RequestRoutingType.java | 6 + .../api/core/cql/BatchStatementBuilder.java | 5 +- .../oss/driver/api/core/session/Request.java | 19 ++ .../internal/core/cql/CqlRequestHandler.java | 30 --- .../core/cql/DefaultBatchStatement.java | 87 ++++-- .../core/cql/DefaultBoundStatement.java | 91 +++++-- .../core/cql/DefaultPrepareRequest.java | 7 + .../core/cql/DefaultSimpleStatement.java | 99 +++++-- .../DefaultLoadBalancingPolicy.java | 251 +++++++++++------- .../example/guava/internal/KeyRequest.java | 7 + 15 files changed, 441 insertions(+), 197 deletions(-) create mode 100644 core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java create mode 100644 core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java index b6fe05a987c..e8e4554e81f 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/BytecodeGraphStatement.java @@ -19,9 +19,11 @@ import com.datastax.dse.driver.api.core.graph.FluentGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.Statement; import com.datastax.oss.driver.api.core.metadata.Node; +import edu.umd.cs.findbugs.annotations.NonNull; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Collections; @@ -127,4 +129,10 @@ protected BytecodeGraphStatement newInstance( readConsistencyLevel, writeConsistencyLevel); } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java index e16287c415d..632d45e61d6 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultBatchGraphStatement.java @@ -19,6 +19,7 @@ import com.datastax.dse.driver.api.core.graph.BatchGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; @@ -151,4 +152,10 @@ protected BatchGraphStatement newInstance( public Iterator iterator() { return this.traversals.iterator(); } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java index 0f6f1faabbf..44fa9e41853 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultFluentGraphStatement.java @@ -19,6 +19,7 @@ import com.datastax.dse.driver.api.core.graph.FluentGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import edu.umd.cs.findbugs.annotations.NonNull; @@ -103,4 +104,10 @@ protected FluentGraphStatement newInstance( public GraphTraversal getTraversal() { return traversal; } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java index 71f79134237..587e1221b41 100644 --- a/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java +++ b/core/src/main/java/com/datastax/dse/driver/internal/core/graph/DefaultScriptGraphStatement.java @@ -19,6 +19,7 @@ import com.datastax.dse.driver.api.core.graph.ScriptGraphStatement; import com.datastax.oss.driver.api.core.ConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.protocol.internal.util.collection.NullAllowingImmutableMap; @@ -204,4 +205,10 @@ protected ScriptGraphStatement newInstance( public String toString() { return String.format("ScriptGraphStatement['%s', params: %s]", this.script, this.queryParams); } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java new file mode 100644 index 00000000000..205f40b1408 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java @@ -0,0 +1,7 @@ +package com.datastax.oss.driver.api.core; + +public enum RequestRoutingMethod { + REGULAR, + PRESERVE_REPLICA_ORDER, + TOKEN_BASED_REPLICA_SHUFFLING +} diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java new file mode 100644 index 00000000000..d8f6d6b9d68 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java @@ -0,0 +1,6 @@ +package com.datastax.oss.driver.api.core; + +public enum RequestRoutingType { + REGULAR, + LWT +} diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java index 26e0aef8ca1..abf3ef0892e 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java @@ -18,6 +18,7 @@ package com.datastax.oss.driver.api.core.cql; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.internal.core.cql.DefaultBatchStatement; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; @@ -152,6 +153,8 @@ public BatchStatementBuilder clearStatements() { @NonNull public BatchStatement build() { List> statements = statementsBuilder.build(); + RequestRoutingType routingType = + isLWT != null ? (isLWT ? RequestRoutingType.LWT : RequestRoutingType.REGULAR) : null; return new DefaultBatchStatement( batchType, statements, @@ -172,7 +175,7 @@ public BatchStatement build() { timeout, node, nowInSeconds, - isLWT); + routingType); } public int getStatementsCount() { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java index 92c25e146c7..99486e6585c 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java @@ -25,6 +25,8 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingMethod; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; @@ -199,4 +201,21 @@ default Partitioner getPartitioner() { /** @return The node configured on this statement, or null if none is configured. */ @Nullable Node getNode(); + + /** + * Returns the routing type for this request. + * + *

The value represents how the request is handled on the server side (for example, regular vs + * lightweight transaction). Load balancing policies use this signal to shape the execution plan + * (eligible coordinators and ordering). + * + * @return The routing type configured on this request + */ + @NonNull + RequestRoutingType getRequestRoutingType(); + + @Nullable + default RequestRoutingMethod getRoutingMethod() { + return RequestRoutingMethod.REGULAR; + } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java index 80eece271a8..4008dd528f0 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java @@ -97,11 +97,9 @@ import java.util.List; import java.util.Map; import java.util.Queue; -import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -208,14 +206,6 @@ public void onThrottleReady(boolean wasDelayed) { Queue queryPlan; if (this.initialStatement.getNode() != null) { queryPlan = new SimpleQueryPlan(this.initialStatement.getNode()); - } else if (this.initialStatement.isLWT()) { - queryPlan = - getReplicas( - session.getKeyspace().orElse(null), - this.initialStatement, - context - .getLoadBalancingPolicyWrapper() - .newQueryPlan(initialStatement, executionProfile.getName(), session)); } else { queryPlan = context @@ -226,26 +216,6 @@ public void onThrottleReady(boolean wasDelayed) { sendRequest(initialStatement, null, queryPlan, 0, 0, true); } - private Queue getReplicas( - CqlIdentifier loggedKeyspace, Statement statement, Queue fallback) { - Token routingToken = getRoutingToken(statement); - CqlIdentifier keyspace = statement.getKeyspace(); - if (keyspace == null) { - keyspace = statement.getRoutingKeyspace(); - if (keyspace == null) { - keyspace = loggedKeyspace; - } - } - - TokenMap tokenMap = context.getMetadataManager().getMetadata().getTokenMap().orElse(null); - if (routingToken == null || keyspace == null || tokenMap == null) { - return fallback; - } - - Set replicas = tokenMap.getReplicas(keyspace, routingToken); - return new ConcurrentLinkedQueue<>(replicas); - } - public CompletionStage handle() { return result; } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index c8cb5b7a084..38bc3af89b7 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchType; @@ -69,7 +70,7 @@ public class DefaultBatchStatement implements BatchStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; - private final Boolean isLWT; + private final RequestRoutingType routingType; public DefaultBatchStatement( BatchType batchType, @@ -91,7 +92,7 @@ public DefaultBatchStatement( Duration timeout, Node node, int nowInSeconds, - Boolean isLWT) { + RequestRoutingType routingType) { for (BatchableStatement statement : statements) { if (statement != null && (statement.getConsistencyLevel() != null @@ -123,7 +124,7 @@ public DefaultBatchStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; - this.isLWT = isLWT; + this.routingType = routingType; } @NonNull @@ -155,7 +156,7 @@ public BatchStatement setBatchType(@NonNull BatchType newBatchType) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -181,7 +182,7 @@ public BatchStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -211,7 +212,7 @@ public BatchStatement add(@NonNull BatchableStatement statement) { timeout, node, nowInSeconds, - isLWT); + routingType); } } @@ -245,7 +246,7 @@ public BatchStatement addAll(@NonNull Iterable> timeout, node, nowInSeconds, - isLWT); + routingType); } } @@ -277,7 +278,7 @@ public BatchStatement clear() { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -314,7 +315,7 @@ public BatchStatement setPagingState(ByteBuffer newPagingState) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -345,7 +346,7 @@ public BatchStatement setPageSize(int newPageSize) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Nullable @@ -377,7 +378,7 @@ public BatchStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste timeout, node, nowInSeconds, - isLWT); + routingType); } @Nullable @@ -410,7 +411,7 @@ public BatchStatement setSerialConsistencyLevel( timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -441,7 +442,7 @@ public BatchStatement setExecutionProfileName(@Nullable String newConfigProfileN timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -472,7 +473,7 @@ public DefaultBatchStatement setExecutionProfile(@Nullable DriverExecutionProfil timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -538,7 +539,7 @@ public BatchStatement setRoutingKeyspace(CqlIdentifier newRoutingKeyspace) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -564,7 +565,7 @@ public BatchStatement setNode(@Nullable Node newNode) { timeout, newNode, nowInSeconds, - isLWT); + routingType); } @Nullable @@ -611,7 +612,7 @@ public BatchStatement setRoutingKey(ByteBuffer newRoutingKey) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -652,7 +653,7 @@ public BatchStatement setRoutingToken(Token newRoutingToken) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -684,7 +685,7 @@ public DefaultBatchStatement setCustomPayload(@NonNull Map n timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -721,7 +722,7 @@ public DefaultBatchStatement setIdempotent(Boolean newIdempotence) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -752,7 +753,7 @@ public BatchStatement setTracing(boolean newTracing) { timeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -783,7 +784,7 @@ public BatchStatement setQueryTimestamp(long newTimestamp) { timeout, node, nowInSeconds, - isLWT); + routingType); } @NonNull @@ -809,7 +810,7 @@ public BatchStatement setTimeout(@Nullable Duration newTimeout) { newTimeout, node, nowInSeconds, - isLWT); + routingType); } @Override @@ -840,12 +841,46 @@ public BatchStatement setNowInSeconds(int newNowInSeconds) { timeout, node, newNowInSeconds, - isLWT); + routingType); + } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return routingType; + } + + @NonNull + @Override + public BatchStatement setRequestRoutingType(RequestRoutingType requestRoutingType) { + return new DefaultBatchStatement( + batchType, + statements, + executionProfileName, + executionProfile, + keyspace, + routingKeyspace, + routingKey, + routingToken, + customPayload, + idempotent, + tracing, + timestamp, + pagingState, + pageSize, + consistencyLevel, + serialConsistencyLevel, + timeout, + node, + nowInSeconds, + requestRoutingType); } @NonNull @Override public BatchStatement setIsLWT(Boolean newIsLWT) { + RequestRoutingType routingType = + newIsLWT != null ? (newIsLWT ? RequestRoutingType.LWT : RequestRoutingType.REGULAR) : null; return new DefaultBatchStatement( batchType, statements, @@ -866,12 +901,12 @@ public BatchStatement setIsLWT(Boolean newIsLWT) { timeout, node, nowInSeconds, - newIsLWT); + routingType); } @Override public boolean isLWT() { - if (isLWT != null) return isLWT; + if (routingType != null) return routingType == RequestRoutingType.LWT; return statements.stream().anyMatch(Statement::isLWT); } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java index 05673692ce9..c60ec4dba6a 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java @@ -26,6 +26,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.ColumnDefinitions; @@ -69,6 +70,7 @@ public class DefaultBoundStatement implements BoundStatement { private final ProtocolVersion protocolVersion; private final Node node; private final int nowInSeconds; + private final RequestRoutingType routingType; public DefaultBoundStatement( PreparedStatement preparedStatement, @@ -91,7 +93,8 @@ public DefaultBoundStatement( CodecRegistry codecRegistry, ProtocolVersion protocolVersion, Node node, - int nowInSeconds) { + int nowInSeconds, + RequestRoutingType routingType) { this.preparedStatement = preparedStatement; this.variableDefinitions = variableDefinitions; this.values = values; @@ -113,6 +116,7 @@ public DefaultBoundStatement( this.protocolVersion = protocolVersion; this.node = node; this.nowInSeconds = nowInSeconds; + this.routingType = routingType; } @Override @@ -207,7 +211,8 @@ public BoundStatement setBytesUnsafe(int i, ByteBuffer v) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @NonNull @@ -251,7 +256,8 @@ public BoundStatement setExecutionProfileName(@Nullable String newConfigProfileN codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -283,7 +289,8 @@ public BoundStatement setExecutionProfile(@Nullable DriverExecutionProfile newPr codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -333,7 +340,8 @@ public BoundStatement setRoutingKeyspace(@Nullable CqlIdentifier newRoutingKeysp codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @NonNull @@ -360,7 +368,8 @@ public BoundStatement setNode(@Nullable Node newNode) { codecRegistry, protocolVersion, newNode, - nowInSeconds); + nowInSeconds, + routingType); } @Nullable @@ -420,7 +429,8 @@ public BoundStatement setRoutingKey(@Nullable ByteBuffer newRoutingKey) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -452,7 +462,8 @@ public BoundStatement setRoutingToken(@Nullable Token newRoutingToken) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @NonNull @@ -485,7 +496,8 @@ public BoundStatement setCustomPayload(@NonNull Map newCusto codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -517,7 +529,8 @@ public BoundStatement setIdempotent(@Nullable Boolean newIdempotence) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -549,7 +562,8 @@ public BoundStatement setTracing(boolean newTracing) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -581,7 +595,8 @@ public BoundStatement setQueryTimestamp(long newTimestamp) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Nullable @@ -614,7 +629,8 @@ public BoundStatement setTimeout(@Nullable Duration newTimeout) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -646,7 +662,8 @@ public BoundStatement setPagingState(@Nullable ByteBuffer newPagingState) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -678,7 +695,8 @@ public BoundStatement setPageSize(int newPageSize) { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Nullable @@ -711,7 +729,8 @@ public BoundStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Nullable @@ -745,7 +764,8 @@ public BoundStatement setSerialConsistencyLevel( codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + routingType); } @Override @@ -777,7 +797,42 @@ public BoundStatement setNowInSeconds(int newNowInSeconds) { codecRegistry, protocolVersion, node, - newNowInSeconds); + newNowInSeconds, + routingType); + } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return routingType; + } + + @NonNull + @Override + public BoundStatement setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { + return new DefaultBoundStatement( + preparedStatement, + variableDefinitions, + values, + executionProfileName, + executionProfile, + routingKeyspace, + routingKey, + routingToken, + customPayload, + idempotent, + tracing, + timestamp, + pagingState, + pageSize, + consistencyLevel, + serialConsistencyLevel, + timeout, + codecRegistry, + protocolVersion, + node, + nowInSeconds, + requestRoutingType); } @Override diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java index 7f87dbe5b51..019b56dbb1f 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPrepareRequest.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.PrepareRequest; import com.datastax.oss.driver.api.core.cql.SimpleStatement; @@ -197,6 +198,12 @@ public Node getNode() { return null; } + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } + @Override public boolean areBoundStatementsTracing() { return statement.isTracing(); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java index 0af32b988fe..6a157b3dfc6 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.metadata.Node; @@ -64,6 +65,7 @@ public class DefaultSimpleStatement implements SimpleStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; + private final RequestRoutingType requestRoutingType; /** @see SimpleStatement#builder(String) */ public DefaultSimpleStatement( @@ -86,7 +88,8 @@ public DefaultSimpleStatement( ConsistencyLevel serialConsistencyLevel, Duration timeout, Node node, - int nowInSeconds) { + int nowInSeconds, + RequestRoutingType requestRoutingType) { if (!positionalValues.isEmpty() && !namedValues.isEmpty()) { throw new IllegalArgumentException("Can't have both positional and named values"); } @@ -110,6 +113,7 @@ public DefaultSimpleStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; + this.requestRoutingType = requestRoutingType; } @NonNull @@ -141,7 +145,8 @@ public SimpleStatement setQuery(@NonNull String newQuery) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -173,7 +178,8 @@ public SimpleStatement setPositionalValues(@NonNull List newPositionalVa serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -205,7 +211,8 @@ public SimpleStatement setNamedValuesWithIds(@NonNull Map serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -237,7 +244,8 @@ public SimpleStatement setExecutionProfileName(@Nullable String newConfigProfile serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -269,7 +277,8 @@ public SimpleStatement setExecutionProfile(@Nullable DriverExecutionProfile newP serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -301,7 +310,8 @@ public SimpleStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -333,7 +343,8 @@ public SimpleStatement setRoutingKeyspace(@Nullable CqlIdentifier newRoutingKeys serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -359,7 +370,8 @@ public SimpleStatement setNode(@Nullable Node newNode) { serialConsistencyLevel, timeout, newNode, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -397,7 +409,8 @@ public SimpleStatement setRoutingKey(@Nullable ByteBuffer newRoutingKey) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -429,7 +442,8 @@ public SimpleStatement setRoutingToken(@Nullable Token newRoutingToken) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @NonNull @@ -461,7 +475,8 @@ public SimpleStatement setCustomPayload(@NonNull Map newCust serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -493,7 +508,8 @@ public SimpleStatement setIdempotent(@Nullable Boolean newIdempotence) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -524,7 +540,8 @@ public SimpleStatement setTracing(boolean newTracing) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -555,7 +572,8 @@ public SimpleStatement setQueryTimestamp(long newTimestamp) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -587,7 +605,8 @@ public SimpleStatement setTimeout(@Nullable Duration newTimeout) { serialConsistencyLevel, newTimeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -619,7 +638,8 @@ public SimpleStatement setPagingState(@Nullable ByteBuffer newPagingState) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -650,7 +670,8 @@ public SimpleStatement setPageSize(int newPageSize) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -682,7 +703,8 @@ public SimpleStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsist serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Nullable @@ -715,7 +737,8 @@ public SimpleStatement setSerialConsistencyLevel( newSerialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } @Override @@ -746,7 +769,41 @@ public SimpleStatement setNowInSeconds(int newNowInSeconds) { serialConsistencyLevel, timeout, node, - newNowInSeconds); + newNowInSeconds, + requestRoutingType); + } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return requestRoutingType; + } + + @NonNull + @Override + public SimpleStatement setRequestRoutingType(RequestRoutingType requestRoutingType) { + return new DefaultSimpleStatement( + query, + positionalValues, + namedValues, + executionProfileName, + executionProfile, + keyspace, + routingKeyspace, + routingKey, + routingToken, + customPayload, + idempotent, + tracing, + timestamp, + pagingState, + pageSize, + consistencyLevel, + serialConsistencyLevel, + timeout, + node, + nowInSeconds, + requestRoutingType); } @Override diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index 0b78141227a..67ae3dabeb9 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -20,6 +20,8 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; +import com.datastax.oss.driver.api.core.RequestRoutingMethod; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -49,7 +51,9 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLongArray; +import java.util.stream.Collectors; import net.jcip.annotations.ThreadSafe; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -131,113 +135,32 @@ protected Optional discoverLocalDc(@NonNull Map nodes) { @NonNull @Override public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { - if (!avoidSlowReplicas) { - return super.newQueryPlan(request, session); + List replicas = getReplicas(request, session); + RequestRoutingType requestType = + Objects.nonNull(request) ? request.getRequestRoutingType() : RequestRoutingType.REGULAR; + boolean isLWT = requestType == RequestRoutingType.LWT; + Object[] currentNodes = + isLWT + ? getReplicasFromLocalDcForLwt(replicas) + : getLiveNodes().dc(getLocalDatacenter()).toArray(); + + if (Objects.nonNull(request) + && request.getRoutingMethod() == RequestRoutingMethod.PRESERVE_REPLICA_ORDER) { + return new SimpleQueryPlan(currentNodes); } - // Take a snapshot since the set is concurrent: - Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray(); - - List allReplicas = getReplicas(request, session); int replicaCount = 0; // in currentNodes - int localRackReplicaCount = 0; // in currentNodes - String localRack = getLocalRack(); - - if (!allReplicas.isEmpty()) { - - // Move replicas to the beginning of the plan - // Replicas in local rack should precede other replicas - for (int i = 0; i < currentNodes.length; i++) { - Node node = (Node) currentNodes[i]; - if (allReplicas.contains(node)) { - if (Objects.equals(node.getRack(), localRack) - && Objects.equals(node.getDatacenter(), getLocalDatacenter())) { - ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); - localRackReplicaCount++; - } else { - ArrayUtils.bubbleUp(currentNodes, i, replicaCount); - } - replicaCount++; - } - } + if (!replicas.isEmpty()) { + Pair counts = moveReplicasToFront(requestType, currentNodes, replicas); + replicaCount = counts.getLeft(); + int localRackReplicaCount = counts.getRight(); // in currentNodes if (replicaCount > 1) { - if (localRack != null && localRackReplicaCount > 0) { - // Shuffle only replicas that are in the local rack - shuffleHead(currentNodes, localRackReplicaCount); - // Shuffles only replicas that are not in local rack - shuffleInRange(currentNodes, localRackReplicaCount, replicaCount - 1); - } else { - shuffleHead(currentNodes, replicaCount); - } + shuffleLocalRackReplicasAndReplicas( + requestType, currentNodes, replicaCount, localRackReplicaCount); - if (replicaCount > 2) { - - assert session != null; - - // Test replicas health - Node newestUpReplica = null; - BitSet unhealthyReplicas = null; // bit mask storing indices of unhealthy replicas - long mostRecentUpTimeNanos = -1; - long now = nanoTime(); - for (int i = 0; i < replicaCount; i++) { - Node node = (Node) currentNodes[i]; - assert node != null; - Long upTimeNanos = upTimes.get(node); - if (upTimeNanos != null - && now - upTimeNanos - NEWLY_UP_INTERVAL_NANOS < 0 - && upTimeNanos - mostRecentUpTimeNanos > 0) { - newestUpReplica = node; - mostRecentUpTimeNanos = upTimeNanos; - } - if (newestUpReplica == null && isUnhealthy(node, session, now)) { - if (unhealthyReplicas == null) { - unhealthyReplicas = new BitSet(replicaCount); - } - unhealthyReplicas.set(i); - } - } - - // When: - // - there isn't any newly UP replica and - // - there is one or more unhealthy replicas and - // - there is a majority of healthy replicas - int unhealthyReplicasCount = - unhealthyReplicas == null ? 0 : unhealthyReplicas.cardinality(); - if (newestUpReplica == null - && unhealthyReplicasCount > 0 - && unhealthyReplicasCount < (replicaCount / 2.0)) { - - // Reorder the unhealthy replicas to the back of the list - // Start from the back of the replicas, then move backwards; - // stop once all unhealthy replicas are moved to the back. - int counter = 0; - for (int i = replicaCount - 1; i >= 0 && counter < unhealthyReplicasCount; i--) { - if (unhealthyReplicas.get(i)) { - ArrayUtils.bubbleDown(currentNodes, i, replicaCount - 1 - counter); - counter++; - } - } - } - - // When: - // - there is a newly UP replica and - // - the replica in first or second position is the most recent replica marked as UP and - // - dice roll 1d4 != 1 - else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) - && diceRoll1d4() != 1) { - - // Send it to the back of the replicas - ArrayUtils.bubbleDown( - currentNodes, newestUpReplica == currentNodes[0] ? 0 : 1, replicaCount - 1); - } - - // Reorder the first two replicas in the shuffled list based on the number of - // in-flight requests - if (getInFlight((Node) currentNodes[0], session) - > getInFlight((Node) currentNodes[1], session)) { - ArrayUtils.swap(currentNodes, 0, 1); - } + if (replicaCount > 2 && avoidSlowReplicas) { + avoidSlowReplicas(Objects.requireNonNull(session), currentNodes, replicaCount); } } } @@ -255,6 +178,132 @@ > getInFlight((Node) currentNodes[1], session)) { return maybeAddDcFailover(request, plan); } + /** For LWT requests, prefer replicas in the local DC to avoid cross-DC coordination */ + private Object[] getReplicasFromLocalDcForLwt(List replicas) { + // For LWT requests, start from replicas; if a local DC is configured, prefer replicas + // in the local DC to avoid cross-DC coordination. Preserve original replica order. + String localDc = getLocalDatacenter(); + if (localDc != null) { + List filtered = + replicas.stream() + .filter(n -> Objects.equals(n.getDatacenter(), localDc)) + .collect(Collectors.toList()); + // Fallback to all replicas if none are in the local DC + if (!filtered.isEmpty()) { + return filtered.toArray(); + } + } + return replicas.toArray(); + } + + private Pair moveReplicasToFront( + RequestRoutingType routingType, Object[] currentNodes, List allReplicas) { + // Note: local rack prioritization is intentionally ignored for LWT requests to prevent + // congestion when different loaders from different racks target distinct rack-local LWT + // leaders. + int replicaCount = 0, localRackReplicaCount = 0; + for (int i = 0; i < currentNodes.length; i++) { + Node node = (Node) currentNodes[i]; + if (allReplicas.contains(node)) { + if (Objects.equals(node.getRack(), getLocalRack()) + && Objects.equals(node.getDatacenter(), getLocalDatacenter()) + && routingType != RequestRoutingType.LWT) { + ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); + localRackReplicaCount++; + } else { + ArrayUtils.bubbleUp(currentNodes, i, replicaCount); + } + replicaCount++; + } + } + return Pair.of(replicaCount, localRackReplicaCount); + } + + private void shuffleLocalRackReplicasAndReplicas( + RequestRoutingType routingType, + Object[] currentNodes, + int replicaCount, + int localRackReplicaCount) { + // For LWT, ignore local rack prioritization to avoid rack-local leader congestion; treat + // all local-DC replicas uniformly. + if (routingType != RequestRoutingType.LWT + && getLocalRack() != null + && localRackReplicaCount > 0) { + // Shuffle only replicas that are in the local rack + shuffleHead(currentNodes, localRackReplicaCount); + // Shuffles only replicas that are not in local rack + shuffleInRange(currentNodes, localRackReplicaCount, replicaCount - 1); + } else { + shuffleHead(currentNodes, replicaCount); + } + } + + private void avoidSlowReplicas( + @NonNull Session session, Object[] currentNodes, int replicaCount) { + // Test replicas health + Node newestUpReplica = null; + BitSet unhealthyReplicas = null; // bit mask storing indices of unhealthy replicas + long mostRecentUpTimeNanos = -1; + long now = nanoTime(); + for (int i = 0; i < replicaCount; i++) { + Node node = (Node) currentNodes[i]; + assert node != null; + Long upTimeNanos = upTimes.get(node); + if (upTimeNanos != null + && now - upTimeNanos - NEWLY_UP_INTERVAL_NANOS < 0 + && upTimeNanos - mostRecentUpTimeNanos > 0) { + newestUpReplica = node; + mostRecentUpTimeNanos = upTimeNanos; + } + if (newestUpReplica == null && isUnhealthy(node, session, now)) { + if (unhealthyReplicas == null) { + unhealthyReplicas = new BitSet(replicaCount); + } + unhealthyReplicas.set(i); + } + } + + // When: + // - there isn't any newly UP replica and + // - there is one or more unhealthy replicas and + // - there is a majority of healthy replicas + int unhealthyReplicasCount = unhealthyReplicas == null ? 0 : unhealthyReplicas.cardinality(); + if (newestUpReplica == null + && unhealthyReplicasCount > 0 + && unhealthyReplicasCount < (replicaCount / 2.0)) { + + // Reorder the unhealthy replicas to the back of the list + // Start from the back of the replicas, then move backwards; + // stop once all unhealthy replicas are moved to the back. + int counter = 0; + for (int i = replicaCount - 1; i >= 0 && counter < unhealthyReplicasCount; i--) { + if (unhealthyReplicas.get(i)) { + ArrayUtils.bubbleDown(currentNodes, i, replicaCount - 1 - counter); + counter++; + } + } + } + + // When: + // - there is a newly UP replica and + // - the replica in first or second position is the most recent replica marked as UP and + // - dice roll 1d4 != 1 + else if ((newestUpReplica == currentNodes[0] || newestUpReplica == currentNodes[1]) + && diceRoll1d4() != 1) { + + // Send it to the back of the replicas + ArrayUtils.bubbleDown( + currentNodes, newestUpReplica == currentNodes[0] ? 0 : 1, replicaCount - 1); + } + + // Reorder the first two replicas in the shuffled list based on the number of + // in-flight requests + if (getInFlight((Node) currentNodes[0], session) + > getInFlight((Node) currentNodes[1], session)) { + ArrayUtils.swap(currentNodes, 0, 1); + } + } + @Override public void onNodeSuccess( @NonNull Request request, diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java b/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java index ef582cce1b9..3c15ed4db52 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/example/guava/internal/KeyRequest.java @@ -18,6 +18,7 @@ package com.datastax.oss.driver.example.guava.internal; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.token.Token; @@ -94,4 +95,10 @@ public Duration getTimeout() { public Node getNode() { return null; } + + @NonNull + @Override + public RequestRoutingType getRequestRoutingType() { + return RequestRoutingType.REGULAR; + } } From 1d82d27176709a9497cb97137d26ee346a31eaf2 Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Fri, 23 Jan 2026 14:37:06 +0100 Subject: [PATCH 2/4] Add configurable LWT request routing method Introduce RequestRoutingType to distinguish LWT from regular queries and add a new configuration option to control LWT routing behavior. The new `advanced.load-balancing-policy.default-lwt-request-routing-method` option allows choosing between: - REGULAR: Default shuffling and slow replica avoidance - PRESERVE_REPLICA_ORDER: Maintains replica order from partitioner Changes: - Add RequestRoutingType enum (REGULAR, LWT) to classify requests - Remove unused RequestRoutingMethod enum from Request interface - Thread RequestRoutingType through Statement builders and implementations - Update DefaultLoadBalancingPolicy to route LWT queries according to config - Add corresponding TypedDriverOption and OptionsMap support - Update prepared statement creation to detect and mark LWT queries - Remove RequestRoutingMethod.getRoutingMethod() default method This enables optimized LWT performance by avoiding unnecessary shuffling when replica order preservation is beneficial for linearizability. refactor: update LWT request routing method to preserve replica order --- .../driver/api/core/RequestRoutingMethod.java | 7 - .../api/core/config/DefaultDriverOption.java | 14 +- .../driver/api/core/config/OptionsMap.java | 3 + .../api/core/config/TypedDriverOption.java | 6 + .../api/core/cql/BoundStatementBuilder.java | 3 +- .../driver/api/core/cql/SimpleStatement.java | 10 +- .../api/core/cql/SimpleStatementBuilder.java | 3 +- .../oss/driver/api/core/cql/Statement.java | 15 ++ .../driver/api/core/cql/StatementBuilder.java | 9 ++ .../oss/driver/api/core/session/Request.java | 8 +- .../driver/internal/core/cql/Conversions.java | 5 +- .../core/cql/DefaultPreparedStatement.java | 16 +- .../DefaultLoadBalancingPolicy.java | 141 ++++++++++++------ core/src/main/resources/reference.conf | 8 + 14 files changed, 173 insertions(+), 75 deletions(-) delete mode 100644 core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java deleted file mode 100644 index 205f40b1408..00000000000 --- a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingMethod.java +++ /dev/null @@ -1,7 +0,0 @@ -package com.datastax.oss.driver.api.core; - -public enum RequestRoutingMethod { - REGULAR, - PRESERVE_REPLICA_ORDER, - TOKEN_BASED_REPLICA_SHUFFLING -} diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java index e651b1d999e..9e0119903df 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/DefaultDriverOption.java @@ -718,7 +718,7 @@ public enum DefaultDriverOption implements DriverOption { /** * CQL 4.x has a known issue where prepared statement invalidation may be bypassed on the client - * side. Reference: https://github.com/scylladb/scylladb/issues/20860 + * side. Reference: link * *

When this occurs, the client's metadata can become outdated, leading to various * deserialization errors. @@ -1063,7 +1063,17 @@ public enum DefaultDriverOption implements DriverOption { *

Value type: {@link java.util.List List}<{@link String}> */ LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS( - "advanced.load-balancing-policy.dc-failover.preferred-remote-dcs"); + "advanced.load-balancing-policy.dc-failover.preferred-remote-dcs"), + + /** + * The default routing method to use for LWT (Lightweight Transaction) requests. REGULAR uses the + * standard load balancing algorithm with slow replica avoidance and shuffling. + * PRESERVE_REPLICA_ORDER maintains the replica order from the partitioner. + * + *

Value-type: string + */ + LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD( + "advanced.load-balancing-policy.default-lwt-request-routing-method"); private final String path; diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java index ed95389f57b..28559ea8556 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/OptionsMap.java @@ -393,6 +393,9 @@ protected static void fillWithDriverDefaults(OptionsMap map) { map.put(TypedDriverOption.METRICS_GENERATE_AGGREGABLE_HISTOGRAMS, true); map.put( TypedDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS, ImmutableList.of("")); + map.put( + TypedDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD, + "PRESERVE_REPLICA_ORDER"); } @Immutable diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java index 1fa752783d8..818468ee9d5 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/config/TypedDriverOption.java @@ -933,6 +933,12 @@ public String toString() { DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS, GenericType.listOf(String.class)); + /** The request routing method to use in the request routing load balancing policy. */ + public static final TypedDriverOption LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD = + new TypedDriverOption<>( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD, + GenericType.STRING); + private static Iterable> introspectBuiltInValues() { try { ImmutableList.Builder> result = ImmutableList.builder(); diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java index 7e8f8723e1b..fbbcccee018 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java @@ -204,6 +204,7 @@ public BoundStatement build() { codecRegistry, protocolVersion, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java index ef04cd14a5b..20f17fa716e 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatement.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.context.DriverContext; import com.datastax.oss.driver.api.core.session.Request; import com.datastax.oss.driver.internal.core.cql.DefaultSimpleStatement; @@ -84,7 +85,8 @@ static SimpleStatement newInstance(@NonNull String cqlQuery) { null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } /** @@ -118,7 +120,8 @@ static SimpleStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } /** @@ -149,7 +152,8 @@ static SimpleStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } /** diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java index 1ac910ff6a7..38deffe404c 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/SimpleStatementBuilder.java @@ -185,6 +185,7 @@ public SimpleStatement build() { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + requestRoutingType); } } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java index 464a0a92a53..68edc3a71a6 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.NoNodeAvailableException; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -525,6 +526,20 @@ default SelfT setNowInSeconds(int nowInSeconds) { return (SelfT) this; } + /** + * Sets the request routing type to use when applying the request (for testing purposes). + * + *

This method's default implementation returns the statement unchanged. The only reason it + * exists is to preserve binary compatibility. Internally, the driver overrides it to record the + * new value. + */ + @NonNull + @CheckReturnValue + @SuppressWarnings("unchecked") + default SelfT setRequestRoutingType(RequestRoutingType requestRoutingType) { + return (SelfT) this; + } + /** * Informs if this is a prepared LWT query. * diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java index 531070b854c..a7247542ea3 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java @@ -19,6 +19,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.token.Token; @@ -61,6 +62,7 @@ public abstract class StatementBuilder< @Nullable protected Duration timeout; @Nullable protected Node node; protected int nowInSeconds = Statement.NO_NOW_IN_SECONDS; + protected RequestRoutingType requestRoutingType = RequestRoutingType.REGULAR; protected StatementBuilder() { // nothing to do @@ -87,6 +89,7 @@ protected StatementBuilder(StatementT template) { this.timeout = template.getTimeout(); this.node = template.getNode(); this.nowInSeconds = template.getNowInSeconds(); + this.requestRoutingType = template.getRequestRoutingType(); } /** @see Statement#setExecutionProfileName(String) */ @@ -282,6 +285,12 @@ public SelfT setNowInSeconds(int nowInSeconds) { return self; } + /** @see Statement#setRequestRoutingType(RequestRoutingType) */ + public SelfT setRequestRoutingType(RequestRoutingType routingType) { + this.requestRoutingType = routingType; + return self; + } + @NonNull protected Map buildCustomPayload() { return (customPayloadBuilder == null) diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java index 99486e6585c..e92e3cc6814 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java @@ -25,7 +25,6 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; -import com.datastax.oss.driver.api.core.RequestRoutingMethod; import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; @@ -103,7 +102,7 @@ public interface Request { * The table to use for tablet-aware routing. Infers the table from available ColumnDefinitions or * {@code null} if it is not possible. * - * @return + * @return The table to use for tablet-aware routing, or {@code null} if not set. */ @Nullable default CqlIdentifier getRoutingTable() { @@ -213,9 +212,4 @@ default Partitioner getPartitioner() { */ @NonNull RequestRoutingType getRequestRoutingType(); - - @Nullable - default RequestRoutingMethod getRoutingMethod() { - return RequestRoutingMethod.REGULAR; - } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java index 88f35eb75a0..0a864293b0d 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; @@ -427,7 +428,9 @@ public static DefaultPreparedStatement toPreparedStatement( request.areBoundStatementsTracing(), context.getCodecRegistry(), context.getProtocolVersion(), - lwtInfo != null && lwtInfo.isLwt(response.variablesMetadata.flags)); + lwtInfo != null && lwtInfo.isLwt(response.variablesMetadata.flags) + ? RequestRoutingType.LWT + : RequestRoutingType.REGULAR); } public static ColumnDefinitions toColumnDefinitions( diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java index dace3647645..6cb2e0b134e 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BoundStatement; @@ -82,7 +83,7 @@ public class DefaultPreparedStatement implements PreparedStatement { private final ConsistencyLevel serialConsistencyLevelForBoundStatements; private final Duration timeoutForBoundStatements; private final Partitioner partitioner; - private final boolean isLWT; + private final RequestRoutingType requestRoutingType; private volatile boolean skipMetadata; public DefaultPreparedStatement( @@ -110,7 +111,7 @@ public DefaultPreparedStatement( boolean areBoundStatementsTracing, CodecRegistry codecRegistry, ProtocolVersion protocolVersion, - boolean isLWT) { + RequestRoutingType requestRoutingType) { this.id = id; this.partitionKeyIndices = partitionKeyIndices; // It's important that we keep a reference to this object, so that it only gets evicted from @@ -136,7 +137,7 @@ public DefaultPreparedStatement( this.codecRegistry = codecRegistry; this.protocolVersion = protocolVersion; - this.isLWT = isLWT; + this.requestRoutingType = requestRoutingType; this.skipMetadata = resolveSkipMetadata( query, resultMetadataId, resultSetDefinitions, this.executionProfileForBoundStatements); @@ -188,7 +189,7 @@ public ColumnDefinitions getResultSetDefinitions() { @Override public boolean isLWT() { - return isLWT; + return requestRoutingType == RequestRoutingType.LWT; } @Override @@ -229,7 +230,8 @@ public BoundStatement bind(@NonNull Object... values) { codecRegistry, protocolVersion, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + requestRoutingType); } @NonNull @@ -263,8 +265,8 @@ public RepreparePayload getRepreparePayload() { } private static class ResultMetadata { - private ByteBuffer resultMetadataId; - private ColumnDefinitions resultSetDefinitions; + private final ByteBuffer resultMetadataId; + private final ColumnDefinitions resultSetDefinitions; private ResultMetadata(ByteBuffer resultMetadataId, ColumnDefinitions resultSetDefinitions) { this.resultMetadataId = resultMetadataId; diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index 67ae3dabeb9..fbe2a6fd9e9 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -20,7 +20,6 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; -import com.datastax.oss.driver.api.core.RequestRoutingMethod; import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; @@ -51,7 +50,6 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLongArray; -import java.util.stream.Collectors; import net.jcip.annotations.ThreadSafe; import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; @@ -72,7 +70,7 @@ * } * * - * See {@code reference.conf} (in the manual or core driver JAR) for more details. + *

See {@code reference.conf} (in the manual or core driver JAR) for more details. * *

Local datacenter: This implementation requires a local datacenter to be defined, * otherwise it will throw an {@link IllegalStateException}. A local datacenter can be supplied @@ -99,6 +97,11 @@ @ThreadSafe public class DefaultLoadBalancingPolicy extends BasicLoadBalancingPolicy implements RequestTracker { + public enum RequestRoutingMethod { + REGULAR, + PRESERVE_REPLICA_ORDER + } + private static final Logger LOG = LoggerFactory.getLogger(DefaultLoadBalancingPolicy.class); private static final long NEWLY_UP_INTERVAL_NANOS = MINUTES.toNanos(1); @@ -108,14 +111,31 @@ public class DefaultLoadBalancingPolicy extends BasicLoadBalancingPolicy impleme protected final ConcurrentMap responseTimes; protected final Map upTimes = new ConcurrentHashMap<>(); private final boolean avoidSlowReplicas; + private final RequestRoutingMethod lwtRequestRoutingMethod; public DefaultLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String profileName) { super(context, profileName); this.avoidSlowReplicas = profile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true); + this.lwtRequestRoutingMethod = getRequestRoutingMethod(); this.responseTimes = new MapMaker().weakKeys().makeMap(); } + @NonNull + private RequestRoutingMethod getRequestRoutingMethod() { + String methodString = + profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD); + try { + return RequestRoutingMethod.valueOf(methodString.toUpperCase()); + } catch (IllegalArgumentException e) { + LOG.warn( + "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER", + logPrefix, + methodString); + return RequestRoutingMethod.PRESERVE_REPLICA_ORDER; + } + } + @NonNull @Override public Optional getRequestTracker() { @@ -132,32 +152,70 @@ protected Optional discoverLocalDc(@NonNull Map nodes) { return new MandatoryLocalDcHelper(context, profile, logPrefix).discoverLocalDc(nodes); } + @NonNull + public RequestRoutingMethod getRequestRoutingMethod(@Nullable Request request) { + if (request == null) { + return RequestRoutingMethod.REGULAR; + } + RequestRoutingType routingType = request.getRequestRoutingType(); + if (routingType == null) { + return RequestRoutingMethod.REGULAR; + } + + switch (routingType) { + case LWT: + return lwtRequestRoutingMethod; + case REGULAR: + return RequestRoutingMethod.REGULAR; + default: + return RequestRoutingMethod.REGULAR; + } + } + @NonNull @Override public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { + switch (getRequestRoutingMethod(request)) { + case PRESERVE_REPLICA_ORDER: + return newQueryPlanPreserveReplicas(request, session); + default: + return newQueryPlanRegular(request, session); + } + } + + @NonNull + /** + * Builds a query plan that preserves the replica order as returned by the load balancing + * strategy, while pushing non-local replicas after local ones. + */ + public Queue newQueryPlanPreserveReplicas( + @Nullable Request request, @Nullable Session session) { List replicas = getReplicas(request, session); - RequestRoutingType requestType = - Objects.nonNull(request) ? request.getRequestRoutingType() : RequestRoutingType.REGULAR; - boolean isLWT = requestType == RequestRoutingType.LWT; - Object[] currentNodes = - isLWT - ? getReplicasFromLocalDcForLwt(replicas) - : getLiveNodes().dc(getLocalDatacenter()).toArray(); - - if (Objects.nonNull(request) - && request.getRoutingMethod() == RequestRoutingMethod.PRESERVE_REPLICA_ORDER) { - return new SimpleQueryPlan(currentNodes); + String localDc = getLocalDatacenter(); + if (localDc == null || replicas.isEmpty()) { + return new SimpleQueryPlan(replicas.toArray()); } + return new SimpleQueryPlan(moveNonLocalReplicasToTheEnd(replicas, localDc)); + } + + @NonNull + /** + * Builds a query plan that prioritizes local replicas, shuffles them for balance, and then + * round-robins the remaining local nodes. + */ + public Queue newQueryPlanRegular(@Nullable Request request, @Nullable Session session) { + List replicas = getReplicas(request, session); + Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray(); int replicaCount = 0; // in currentNodes if (!replicas.isEmpty()) { - Pair counts = moveReplicasToFront(requestType, currentNodes, replicas); + Pair counts = moveReplicasToFront(currentNodes, replicas); replicaCount = counts.getLeft(); + int localRackReplicaCount = counts.getRight(); // in currentNodes if (replicaCount > 1) { - shuffleLocalRackReplicasAndReplicas( - requestType, currentNodes, replicaCount, localRackReplicaCount); + shuffleLocalRackReplicasAndReplicas(currentNodes, replicaCount, localRackReplicaCount); if (replicaCount > 2 && avoidSlowReplicas) { avoidSlowReplicas(Objects.requireNonNull(session), currentNodes, replicaCount); @@ -178,36 +236,34 @@ public Queue newQueryPlan(@Nullable Request request, @Nullable Session ses return maybeAddDcFailover(request, plan); } - /** For LWT requests, prefer replicas in the local DC to avoid cross-DC coordination */ - private Object[] getReplicasFromLocalDcForLwt(List replicas) { - // For LWT requests, start from replicas; if a local DC is configured, prefer replicas - // in the local DC to avoid cross-DC coordination. Preserve original replica order. - String localDc = getLocalDatacenter(); - if (localDc != null) { - List filtered = - replicas.stream() - .filter(n -> Objects.equals(n.getDatacenter(), localDc)) - .collect(Collectors.toList()); - // Fallback to all replicas if none are in the local DC - if (!filtered.isEmpty()) { - return filtered.toArray(); + /** + * Returns a replica array with local-datacenter replicas first and remote replicas preserved at + * the end. + */ + private static Object[] moveNonLocalReplicasToTheEnd(List replicas, String localDc) { + Object[] orderedReplicas = new Object[replicas.size()]; + int index = 0; + for (Node replica : replicas) { + if (Objects.equals(replica.getDatacenter(), localDc)) { + orderedReplicas[index++] = replica; + } + } + for (Node replica : replicas) { + if (!Objects.equals(replica.getDatacenter(), localDc)) { + orderedReplicas[index++] = replica; } } - return replicas.toArray(); + return orderedReplicas; } private Pair moveReplicasToFront( - RequestRoutingType routingType, Object[] currentNodes, List allReplicas) { - // Note: local rack prioritization is intentionally ignored for LWT requests to prevent - // congestion when different loaders from different racks target distinct rack-local LWT - // leaders. + Object[] currentNodes, List allReplicas) { int replicaCount = 0, localRackReplicaCount = 0; for (int i = 0; i < currentNodes.length; i++) { Node node = (Node) currentNodes[i]; if (allReplicas.contains(node)) { if (Objects.equals(node.getRack(), getLocalRack()) - && Objects.equals(node.getDatacenter(), getLocalDatacenter()) - && routingType != RequestRoutingType.LWT) { + && Objects.equals(node.getDatacenter(), getLocalDatacenter())) { ArrayUtils.bubbleUp(currentNodes, i, localRackReplicaCount); localRackReplicaCount++; } else { @@ -220,15 +276,8 @@ private Pair moveReplicasToFront( } private void shuffleLocalRackReplicasAndReplicas( - RequestRoutingType routingType, - Object[] currentNodes, - int replicaCount, - int localRackReplicaCount) { - // For LWT, ignore local rack prioritization to avoid rack-local leader congestion; treat - // all local-DC replicas uniformly. - if (routingType != RequestRoutingType.LWT - && getLocalRack() != null - && localRackReplicaCount > 0) { + Object[] currentNodes, int replicaCount, int localRackReplicaCount) { + if (getLocalRack() != null && localRackReplicaCount > 0) { // Shuffle only replicas that are in the local rack shuffleHead(currentNodes, localRackReplicaCount); // Shuffles only replicas that are not in local rack diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index 161cd4bc91a..d994390bdf7 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -651,6 +651,14 @@ datastax-java-driver { # Overridable in a profile: no preferred-remote-dcs = [""] } + # The method to use when routing requests. + # Options are: + # - "REGULAR": default behavior of the load balancing policy includes avoiding slow replicas and shuffling nodes + # - "PRESERVE_REPLICA_ORDER": tries to preserve the order of replicas as returned by the partitioner when building the query plan. + # Required: no + # Modifiable at runtime: no + # Overridable in a profile: yes + default-lwt-request-routing-method = "REGULAR" } # Whether to schedule reconnection attempts if all contact points are unreachable on the first From f78df0476d99b76daf961df244a48c4dac001977 Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Fri, 23 Jan 2026 14:45:30 +0100 Subject: [PATCH 3/4] Fix code style and improve test consistency for LWT feature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix Javadoc positioning: Move @NonNull annotations after doc comments in DefaultLoadBalancingPolicy methods (per Java conventions) - Add missing @NonNull annotation to StatementBuilderTest mock builder - Add @Nullable annotation to NodeStateIT query plan method signature - Standardize test infrastructure: * Add @RunWith(MockitoJUnitRunner.Silent.class) to 7 test classes * Update LoadBalancingPolicyTestBase to stub LWT routing config option * Convert base class from @RunWith to abstract (subclasses now declare runner) - Standardize integration test naming: ccmRule→CCM_RULE, sessionRule→SESSION_RULE - Update test mocks with RequestRoutingType.REGULAR parameter for compatibility - Improve LWT integration tests: * BatchStatementIT: Fix variable references, enhance LWT batch assertions * LWTLoadBalancingIT: Change from single-node to replica-set validation * Add LWTLoadBalancingMultiDcIT: New multi-DC LWT routing test coverage No functional changes to production code—purely code quality and test improvements. Apply suggestions from code review Co-authored-by: Dmitry Kropachev --- .../driver/api/core/cql/StatementBuilder.java | 4 +- .../core/cql/DefaultBatchStatement.java | 4 +- .../core/cql/DefaultBoundStatement.java | 6 +- .../core/cql/DefaultPreparedStatement.java | 4 +- .../core/cql/DefaultSimpleStatement.java | 6 +- .../DefaultLoadBalancingPolicy.java | 21 +- core/src/main/resources/reference.conf | 4 +- .../api/core/cql/StatementBuilderTest.java | 2 + .../api/core/cql/StatementProfileTest.java | 4 +- .../internal/core/cql/StatementSizeTest.java | 4 +- ...BasicLoadBalancingPolicyQueryPlanTest.java | 2 + ...cInferringLoadBalancingPolicyInitTest.java | 3 + .../DefaultLoadBalancingPolicyConfigTest.java | 114 ++++++++ .../DefaultLoadBalancingPolicyInitTest.java | 3 + ...aultLoadBalancingPolicyLwtRoutingTest.java | 253 +++++++++++++++++ ...faultLoadBalancingPolicyQueryPlanTest.java | 71 +++++ ...LoadBalancingPolicyRequestRoutingTest.java | 257 ++++++++++++++++++ ...LoadBalancingPolicyRequestTrackerTest.java | 3 + .../LoadBalancingPolicyTestBase.java | 6 +- .../core/tracker/RequestLogFormatterTest.java | 3 +- .../oss/driver/core/cql/BatchStatementIT.java | 97 ++++--- .../loadbalancing/LWTLoadBalancingIT.java | 25 +- .../LWTLoadBalancingMultiDcIT.java | 209 ++++++++++++++ .../oss/driver/core/metadata/NodeStateIT.java | 3 +- 24 files changed, 1029 insertions(+), 79 deletions(-) create mode 100644 core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyConfigTest.java create mode 100644 core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java create mode 100644 core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java create mode 100644 integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java index a7247542ea3..ecfb5f57023 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java @@ -62,7 +62,7 @@ public abstract class StatementBuilder< @Nullable protected Duration timeout; @Nullable protected Node node; protected int nowInSeconds = Statement.NO_NOW_IN_SECONDS; - protected RequestRoutingType requestRoutingType = RequestRoutingType.REGULAR; + @NonNull protected RequestRoutingType requestRoutingType = RequestRoutingType.REGULAR; protected StatementBuilder() { // nothing to do @@ -286,7 +286,7 @@ public SelfT setNowInSeconds(int nowInSeconds) { } /** @see Statement#setRequestRoutingType(RequestRoutingType) */ - public SelfT setRequestRoutingType(RequestRoutingType routingType) { + public SelfT setRequestRoutingType(@NonNull RequestRoutingType routingType) { this.requestRoutingType = routingType; return self; } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index 38bc3af89b7..582c326743b 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -70,7 +70,7 @@ public class DefaultBatchStatement implements BatchStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; - private final RequestRoutingType routingType; + @NonNull private final RequestRoutingType routingType; public DefaultBatchStatement( BatchType batchType, @@ -92,7 +92,7 @@ public DefaultBatchStatement( Duration timeout, Node node, int nowInSeconds, - RequestRoutingType routingType) { + @NonNull RequestRoutingType routingType) { for (BatchableStatement statement : statements) { if (statement != null && (statement.getConsistencyLevel() != null diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java index c60ec4dba6a..c2024dcf8b0 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java @@ -70,7 +70,7 @@ public class DefaultBoundStatement implements BoundStatement { private final ProtocolVersion protocolVersion; private final Node node; private final int nowInSeconds; - private final RequestRoutingType routingType; + @NonNull private final RequestRoutingType routingType; public DefaultBoundStatement( PreparedStatement preparedStatement, @@ -94,7 +94,7 @@ public DefaultBoundStatement( ProtocolVersion protocolVersion, Node node, int nowInSeconds, - RequestRoutingType routingType) { + @NonNull RequestRoutingType routingType) { this.preparedStatement = preparedStatement; this.variableDefinitions = variableDefinitions; this.values = values; @@ -809,7 +809,7 @@ public RequestRoutingType getRequestRoutingType() { @NonNull @Override - public BoundStatement setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { + public BoundStatement setRequestRoutingType(@NonNull RequestRoutingType requestRoutingType) { return new DefaultBoundStatement( preparedStatement, variableDefinitions, diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java index 6cb2e0b134e..3994a5683ba 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java @@ -83,7 +83,7 @@ public class DefaultPreparedStatement implements PreparedStatement { private final ConsistencyLevel serialConsistencyLevelForBoundStatements; private final Duration timeoutForBoundStatements; private final Partitioner partitioner; - private final RequestRoutingType requestRoutingType; + @NonNull private final RequestRoutingType requestRoutingType; private volatile boolean skipMetadata; public DefaultPreparedStatement( @@ -111,7 +111,7 @@ public DefaultPreparedStatement( boolean areBoundStatementsTracing, CodecRegistry codecRegistry, ProtocolVersion protocolVersion, - RequestRoutingType requestRoutingType) { + @NonNull RequestRoutingType requestRoutingType) { this.id = id; this.partitionKeyIndices = partitionKeyIndices; // It's important that we keep a reference to this object, so that it only gets evicted from diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java index 6a157b3dfc6..f1a0495d6ed 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java @@ -65,7 +65,7 @@ public class DefaultSimpleStatement implements SimpleStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; - private final RequestRoutingType requestRoutingType; + @NonNull private final RequestRoutingType requestRoutingType; /** @see SimpleStatement#builder(String) */ public DefaultSimpleStatement( @@ -89,7 +89,7 @@ public DefaultSimpleStatement( Duration timeout, Node node, int nowInSeconds, - RequestRoutingType requestRoutingType) { + @NonNull RequestRoutingType requestRoutingType) { if (!positionalValues.isEmpty() && !namedValues.isEmpty()) { throw new IllegalArgumentException("Can't have both positional and named values"); } @@ -781,7 +781,7 @@ public RequestRoutingType getRequestRoutingType() { @NonNull @Override - public SimpleStatement setRequestRoutingType(RequestRoutingType requestRoutingType) { + public SimpleStatement setRequestRoutingType(@NonNull RequestRoutingType requestRoutingType) { return new DefaultSimpleStatement( query, positionalValues, diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index fbe2a6fd9e9..66a1c13b3eb 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -20,7 +20,6 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; -import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -117,12 +116,12 @@ public DefaultLoadBalancingPolicy(@NonNull DriverContext context, @NonNull Strin super(context, profileName); this.avoidSlowReplicas = profile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true); - this.lwtRequestRoutingMethod = getRequestRoutingMethod(); + this.lwtRequestRoutingMethod = getDefaultLWTRequestRoutingMethod(); this.responseTimes = new MapMaker().weakKeys().makeMap(); } @NonNull - private RequestRoutingMethod getRequestRoutingMethod() { + private RequestRoutingMethod getDefaultLWTRequestRoutingMethod() { String methodString = profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD); try { @@ -153,20 +152,14 @@ protected Optional discoverLocalDc(@NonNull Map nodes) { } @NonNull - public RequestRoutingMethod getRequestRoutingMethod(@Nullable Request request) { + public RequestRoutingMethod getDefaultLWTRequestRoutingMethod(@Nullable Request request) { if (request == null) { return RequestRoutingMethod.REGULAR; } - RequestRoutingType routingType = request.getRequestRoutingType(); - if (routingType == null) { - return RequestRoutingMethod.REGULAR; - } - - switch (routingType) { + switch (request.getRequestRoutingType()) { case LWT: return lwtRequestRoutingMethod; case REGULAR: - return RequestRoutingMethod.REGULAR; default: return RequestRoutingMethod.REGULAR; } @@ -175,7 +168,7 @@ public RequestRoutingMethod getRequestRoutingMethod(@Nullable Request request) { @NonNull @Override public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { - switch (getRequestRoutingMethod(request)) { + switch (getDefaultLWTRequestRoutingMethod(request)) { case PRESERVE_REPLICA_ORDER: return newQueryPlanPreserveReplicas(request, session); default: @@ -183,11 +176,11 @@ public Queue newQueryPlan(@Nullable Request request, @Nullable Session ses } } - @NonNull /** * Builds a query plan that preserves the replica order as returned by the load balancing * strategy, while pushing non-local replicas after local ones. */ + @NonNull public Queue newQueryPlanPreserveReplicas( @Nullable Request request, @Nullable Session session) { List replicas = getReplicas(request, session); @@ -199,11 +192,11 @@ public Queue newQueryPlanPreserveReplicas( return new SimpleQueryPlan(moveNonLocalReplicasToTheEnd(replicas, localDc)); } - @NonNull /** * Builds a query plan that prioritizes local replicas, shuffles them for balance, and then * round-robins the remaining local nodes. */ + @NonNull public Queue newQueryPlanRegular(@Nullable Request request, @Nullable Session session) { List replicas = getReplicas(request, session); Object[] currentNodes = getLiveNodes().dc(getLocalDatacenter()).toArray(); diff --git a/core/src/main/resources/reference.conf b/core/src/main/resources/reference.conf index d994390bdf7..40d56d67341 100644 --- a/core/src/main/resources/reference.conf +++ b/core/src/main/resources/reference.conf @@ -654,11 +654,11 @@ datastax-java-driver { # The method to use when routing requests. # Options are: # - "REGULAR": default behavior of the load balancing policy includes avoiding slow replicas and shuffling nodes - # - "PRESERVE_REPLICA_ORDER": tries to preserve the order of replicas as returned by the partitioner when building the query plan. + # - "PRESERVE_REPLICA_ORDER": tries to preserve the order of replicas as returned by the partitioner when building the query plan. When dc is provided, move replicas from non-local dc to the back of query plan, but ignores local rack. # Required: no # Modifiable at runtime: no # Overridable in a profile: yes - default-lwt-request-routing-method = "REGULAR" + default-lwt-request-routing-method = "PRESERVE_REPLICA_ORDER" } # Whether to schedule reconnection attempts if all contact points are unreachable on the first diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java index 9904b1e27d7..a10208645fd 100644 --- a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java +++ b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementBuilderTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.shaded.guava.common.base.Charsets; +import edu.umd.cs.findbugs.annotations.NonNull; import java.nio.ByteBuffer; import org.junit.Test; @@ -38,6 +39,7 @@ public MockSimpleStatementBuilder(SimpleStatement template) { super(template); } + @NonNull @Override public SimpleStatement build() { diff --git a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java index af2dccd0432..d59d3a460b9 100644 --- a/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java +++ b/core/src/test/java/com/datastax/oss/driver/api/core/cql/StatementProfileTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.TestDataProviders; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.internal.core.cql.DefaultBoundStatement; import com.tngtech.java.junit.dataprovider.DataProvider; @@ -191,6 +192,7 @@ private static BoundStatement newBoundStatement() { null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java index dc3ab0702f7..1291e0f8a49 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/StatementSizeTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverConfig; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.cql.BatchStatement; @@ -287,6 +288,7 @@ private BoundStatement newBoundStatement( CodecRegistry.DEFAULT, DefaultProtocolVersion.V5, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + RequestRoutingType.REGULAR); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java index c2e89cdf07c..428bb5db4f6 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicyQueryPlanTest.java @@ -38,6 +38,7 @@ import static org.mockito.Mockito.when; import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.Metadata; import com.datastax.oss.driver.api.core.metadata.TokenMap; @@ -80,6 +81,7 @@ public void setup() { when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); when(metadataManager.getMetadata()).thenReturn(metadata); when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); + when(request.getRequestRoutingType()).thenReturn(RequestRoutingType.REGULAR); policy = createAndInitPolicy(); } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java index 20de3afe9c3..8440fa3bd6b 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DcInferringLoadBalancingPolicyInitTest.java @@ -36,7 +36,10 @@ import edu.umd.cs.findbugs.annotations.NonNull; import java.util.UUID; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DcInferringLoadBalancingPolicyInitTest extends LoadBalancingPolicyTestBase { @Test diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyConfigTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyConfigTest.java new file mode 100644 index 00000000000..768722e0e86 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyConfigTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; +import com.tngtech.java.junit.dataprovider.DataProvider; +import com.tngtech.java.junit.dataprovider.DataProviderRunner; +import com.tngtech.java.junit.dataprovider.UseDataProvider; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.MockitoAnnotations; + +@RunWith(DataProviderRunner.class) +public class DefaultLoadBalancingPolicyConfigTest extends LoadBalancingPolicyTestBase { + + @Before + @Override + public void setup() { + MockitoAnnotations.initMocks(this); + super.setup(); + } + + @Test + @DataProvider(value = {"REGULAR", "regular", "PRESERVE_REPLICA_ORDER", "Preserve_Replica_Order"}) + public void should_accept_valid_routing_methods(String routingMethod) { + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(routingMethod); + DefaultLoadBalancingPolicy policy = + new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + assertThat(policy).isNotNull(); + } + + @Test + @DataProvider( + value = {"INVALID_METHOD", "", "@#$%^&*()", " REGULAR "}, + trimValues = false) + public void should_default_to_preserve_replica_order_for_invalid_routing_methods( + String invalidValue) { + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(invalidValue); + DefaultLoadBalancingPolicy policy = + new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + + assertThat(policy).isNotNull(); + + verify(appender).doAppend(loggingEventCaptor.capture()); + assertThat(loggingEventCaptor.getValue().getFormattedMessage()) + .contains("Unknown request routing method") + .contains("defaulting to PRESERVE_REPLICA_ORDER"); + } + + @Test + @UseDataProvider("configurationCombinations") + public void should_accept_configuration_combinations( + String routingMethod, boolean slowAvoidance) { + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(routingMethod); + when(defaultProfile.getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true)) + .thenReturn(slowAvoidance); + + DefaultLoadBalancingPolicy policy = + new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + assertThat(policy).isNotNull(); + + verify(defaultProfile, atLeast(1)) + .getBoolean(DefaultDriverOption.LOAD_BALANCING_POLICY_SLOW_AVOIDANCE, true); + } + + @DataProvider + public static Object[][] configurationCombinations() { + return new Object[][] { + {"PRESERVE_REPLICA_ORDER", false}, + {"REGULAR", true} + }; + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java index 77887d627f9..53d9633a23d 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyInitTest.java @@ -36,7 +36,10 @@ import edu.umd.cs.findbugs.annotations.NonNull; import java.util.UUID; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DefaultLoadBalancingPolicyInitTest extends LoadBalancingPolicyTestBase { @Test diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java new file mode 100644 index 00000000000..1e16aafa5f2 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyLwtRoutingTest.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.api.core.metadata.Metadata; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.TokenMap; +import com.datastax.oss.driver.api.core.metadata.token.Token; +import com.datastax.oss.driver.api.core.session.Request; +import com.datastax.oss.driver.internal.core.session.DefaultSession; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; +import com.datastax.oss.protocol.internal.util.Bytes; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.Silent.class) +public class DefaultLoadBalancingPolicyLwtRoutingTest extends LoadBalancingPolicyTestBase { + + private static final CqlIdentifier KEYSPACE = CqlIdentifier.fromInternal("ks"); + private static final ByteBuffer ROUTING_KEY = Bytes.fromHexString("0xdeadbeef"); + + @Mock protected Request request; + @Mock protected DefaultSession session; + @Mock protected Metadata metadata; + @Mock protected TokenMap tokenMap; + @Mock protected Token routingToken; + + private DefaultLoadBalancingPolicy policy; + + @Before + @Override + public void setup() { + super.setup(); + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + when(metadataManager.getMetadata()).thenReturn(metadata); + when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); + + // Set up nodes with proper DCs + when(node1.getDatacenter()).thenReturn("dc1"); + when(node2.getDatacenter()).thenReturn("dc1"); + when(node3.getDatacenter()).thenReturn("dc1"); + when(node4.getDatacenter()).thenReturn("dc2"); + when(node5.getDatacenter()).thenReturn("dc2"); + + // Configure for PRESERVE_REPLICA_ORDER routing for LWT + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn("PRESERVE_REPLICA_ORDER"); + + policy = new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + policy.init( + ImmutableMap.of( + UUID.randomUUID(), node1, + UUID.randomUUID(), node2, + UUID.randomUUID(), node3, + UUID.randomUUID(), node4, + UUID.randomUUID(), node5), + distanceReporter); + } + + @Test + public void should_preserve_replica_order_with_empty_replicas() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)).willReturn(ImmutableList.of()); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then + assertThat(plan).isEmpty(); + } + + @Test + public void should_preserve_replica_order_with_single_local_replica() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then + assertThat(plan).containsExactly(node2); + } + + @Test + public void should_preserve_replica_order_with_multiple_local_replicas() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node3, node1, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - order preserved exactly as returned from token map + assertThat(plan).containsExactly(node3, node1, node2); + } + + @Test + public void should_push_remote_replicas_to_end() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + // Token map returns replicas in mixed order: remote, local, remote, local + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node4, node1, node5, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - local replicas first (preserving their order), remote replicas last (preserving their + // order) + assertThat(plan).containsExactly(node1, node2, node4, node5); + } + + @Test + public void should_preserve_replica_order_with_all_remote_replicas() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node5, node4)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - all remote replicas, order preserved + assertThat(plan).containsExactly(node5, node4); + } + + @Test + public void should_handle_null_local_datacenter() { + // Given + when(defaultProfile.isDefined(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER)) + .thenReturn(false); + + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When - calling with request that might not have local DC set + // The method should handle null localDc gracefully and just return replicas as-is + Queue plan = policy.newQueryPlanPreserveReplicas(request, session); + + // Then - returns all replicas in order when localDc is not defined + assertThat(plan).containsExactly(node1, node2); + } + + @Test + public void should_preserve_order_when_no_routing_key() { + // Given + given(request.getRoutingKeyspace()).willReturn(null); + given(request.getRoutingKey()).willReturn(null); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - with no routing key, no replicas identified, falls back to empty or default behavior + // This tests the edge case where getReplicas returns empty list + assertThat(plan).isNotNull(); + } + + @Test + public void should_dispatch_to_preserve_replicas_when_lwt_and_config_set() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - verify it used preserve replica order (no shuffling) + // Call multiple times to ensure order is always preserved (not shuffled) + Queue plan2 = policy.newQueryPlan(request, session); + Queue plan3 = policy.newQueryPlan(request, session); + + assertThat(plan).containsExactly(node1, node2); + assertThat(plan2).containsExactly(node1, node2); + assertThat(plan3).containsExactly(node1, node2); + } + + @Test + public void should_not_add_non_replicas_in_preserve_mode() { + // Given + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + // Only node1 is a replica + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - only the replica is in the plan, other live nodes are NOT added + assertThat(plan).containsExactly(node1); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java index f016323c16b..f9445b84d76 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyQueryPlanTest.java @@ -33,6 +33,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.internal.core.pool.ChannelPool; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; @@ -44,8 +45,11 @@ import java.util.concurrent.atomic.AtomicLongArray; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DefaultLoadBalancingPolicyQueryPlanTest extends BasicLoadBalancingPolicyQueryPlanTest { private static final long T0 = Long.MIN_VALUE; @@ -387,6 +391,73 @@ public void should_prefer_local_rack_replica_with_less_inflight_requests() { assertThat(plan2).containsExactly(node5, node3, node1, node4, node2); } + @Test + public void should_ignore_local_rack_prioritization_for_lwt_requests() { + // Given - LWT request with local rack configured + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node3, node5)); + + String localRack = "rack1"; + given(dsePolicy.getLocalRack()).willReturn(localRack); + // Only node1 is in the local rack + given(node1.getRack()).willReturn(localRack); + given(node3.getRack()).willReturn("rack2"); + given(node5.getRack()).willReturn("rack3"); + + given(pool1.getInFlight()).willReturn(0); + given(pool3.getInFlight()).willReturn(0); + given(pool5.getInFlight()).willReturn(0); + + // When + Queue plan1 = dsePolicy.newQueryPlan(request, session); + Queue plan2 = dsePolicy.newQueryPlan(request, session); + + // Then - for LWT requests (RequestRoutingType.LWT) the policy should ignore local rack + // prioritization and preserve the replica order returned by the token map. + // The shuffle methods are still invoked for the non-replica range, so only the non-replica + // nodes (node2 and node4) are permuted between successive plans. + then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); + then(dsePolicy).should(times(2)).shuffleInRange(any(), anyInt(), anyInt()); + assertThat(plan1).containsExactly(node1, node3, node5, node2, node4); + assertThat(plan2).containsExactly(node1, node3, node5, node4, node2); + } + + @Test + public void should_respect_local_rack_prioritization_for_regular_requests() { + // Given - REGULAR request (not LWT) with local rack configured + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()) + .willReturn(com.datastax.oss.driver.api.core.RequestRoutingType.REGULAR); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node3, node5)); + + String localRack = "rack1"; + given(dsePolicy.getLocalRack()).willReturn(localRack); + // node1 is in the local rack + given(node1.getRack()).willReturn(localRack); + given(node3.getRack()).willReturn("rack2"); + given(node5.getRack()).willReturn("rack3"); + + given(pool1.getInFlight()).willReturn(0); + given(pool3.getInFlight()).willReturn(0); + given(pool5.getInFlight()).willReturn(0); + + // When + Queue plan1 = dsePolicy.newQueryPlan(request, session); + Queue plan2 = dsePolicy.newQueryPlan(request, session); + + // Then - local rack replica prioritized and shuffled separately from others + // Verify that local rack replicas and non-local-rack replicas are shuffled separately + then(dsePolicy).should(times(2)).shuffleHead(any(), anyInt()); + then(dsePolicy).should(times(2)).shuffleInRange(any(), anyInt(), anyInt()); + assertThat(plan1).containsExactly(node1, node3, node5, node2, node4); + assertThat(plan2).containsExactly(node1, node3, node5, node4, node2); + } + @Override protected DefaultLoadBalancingPolicy createAndInitPolicy() { DefaultLoadBalancingPolicy policy = diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java new file mode 100644 index 00000000000..9aef1825329 --- /dev/null +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestRoutingTest.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2020 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.internal.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.RequestRoutingType; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; +import com.datastax.oss.driver.api.core.metadata.Metadata; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.TokenMap; +import com.datastax.oss.driver.api.core.metadata.token.Token; +import com.datastax.oss.driver.api.core.session.Request; +import com.datastax.oss.driver.internal.core.loadbalancing.DefaultLoadBalancingPolicy.RequestRoutingMethod; +import com.datastax.oss.driver.internal.core.session.DefaultSession; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap; +import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableSet; +import com.datastax.oss.protocol.internal.util.Bytes; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.Silent.class) +public class DefaultLoadBalancingPolicyRequestRoutingTest extends LoadBalancingPolicyTestBase { + + private static final CqlIdentifier KEYSPACE = CqlIdentifier.fromInternal("ks"); + private static final ByteBuffer ROUTING_KEY = Bytes.fromHexString("0xdeadbeef"); + + @Mock protected Request request; + @Mock protected DefaultSession session; + @Mock protected Metadata metadata; + @Mock protected TokenMap tokenMap; + @Mock protected Token routingToken; + + private DefaultLoadBalancingPolicy policy; + + @Before + @Override + public void setup() { + super.setup(); + when(metadataManager.getContactPoints()).thenReturn(ImmutableSet.of(node1)); + when(metadataManager.getMetadata()).thenReturn(metadata); + when(metadata.getTokenMap()).thenAnswer(invocation -> Optional.of(this.tokenMap)); + + when(node1.getDatacenter()).thenReturn("dc1"); + when(node2.getDatacenter()).thenReturn("dc1"); + when(node3.getDatacenter()).thenReturn("dc1"); + } + + private void initPolicy(String routingMethod) { + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn(routingMethod); + + policy = new DefaultLoadBalancingPolicy(context, DriverExecutionProfile.DEFAULT_NAME); + policy.init( + ImmutableMap.of( + UUID.randomUUID(), node1, + UUID.randomUUID(), node2, + UUID.randomUUID(), node3), + distanceReporter); + } + + @Test + public void should_return_regular_when_request_is_null() { + // Given + initPolicy("REGULAR"); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(null); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_return_regular_when_routing_type_is_regular() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_return_regular_for_lwt_when_config_is_regular() { + // Given + initPolicy("REGULAR"); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_return_preserve_replica_order_for_lwt_when_config_is_preserve() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then + assertThat(method).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + } + + @Test + public void should_dispatch_to_regular_query_plan_when_request_is_regular() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When + Queue plan1 = policy.newQueryPlan(request, session); + Queue plan2 = policy.newQueryPlan(request, session); + + // Then - regular routing shuffles replicas (node1, node2), and also adds the local + // non-replica node (node3); order may vary between plans but the same three nodes + // must be present in each plan + assertThat(plan1).containsExactlyInAnyOrder(node1, node2, node3); + assertThat(plan2).containsExactlyInAnyOrder(node1, node2, node3); + } + + @Test + public void should_dispatch_to_preserve_query_plan_when_lwt_and_config_preserve() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node2, node1)); + + // When + Queue plan1 = policy.newQueryPlan(request, session); + Queue plan2 = policy.newQueryPlan(request, session); + Queue plan3 = policy.newQueryPlan(request, session); + + // Then - preserve routing maintains exact order + assertThat(plan1).containsExactly(node2, node1); + assertThat(plan2).containsExactly(node2, node1); + assertThat(plan3).containsExactly(node2, node1); + } + + @Test + public void should_dispatch_to_regular_query_plan_when_lwt_but_config_regular() { + // Given + initPolicy("REGULAR"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2)); + + // When + Queue plan = policy.newQueryPlan(request, session); + + // Then - uses regular routing which may shuffle and add non-replicas + assertThat(plan).containsExactlyInAnyOrder(node1, node2, node3); + } + + @Test + public void should_handle_null_request_in_new_query_plan() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + + // When + Queue plan = policy.newQueryPlan(null, session); + + // Then - null request should use regular routing + assertThat(plan).isNotNull(); + assertThat(plan).containsExactlyInAnyOrder(node1, node2, node3); + } + + @Test + public void should_use_regular_routing_for_unknown_routing_type() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + // Use REGULAR as a stand-in for any "unknown" type - the switch has a default case + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.REGULAR); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1)); + + // When + RequestRoutingMethod method = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then - defaults to REGULAR for any unrecognized type + assertThat(method).isEqualTo(RequestRoutingMethod.REGULAR); + } + + @Test + public void should_consistently_route_same_request_type() { + // Given + initPolicy("PRESERVE_REPLICA_ORDER"); + given(request.getRoutingKeyspace()).willReturn(KEYSPACE); + given(request.getRoutingKey()).willReturn(ROUTING_KEY); + given(request.getRequestRoutingType()).willReturn(RequestRoutingType.LWT); + given(tokenMap.getReplicasList(KEYSPACE, null, ROUTING_KEY)) + .willReturn(ImmutableList.of(node1, node2, node3)); + + // When - call multiple times + RequestRoutingMethod method1 = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method2 = policy.getDefaultLWTRequestRoutingMethod(request); + RequestRoutingMethod method3 = policy.getDefaultLWTRequestRoutingMethod(request); + + // Then - should always return the same method + assertThat(method1).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + assertThat(method2).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + assertThat(method3).isEqualTo(RequestRoutingMethod.PRESERVE_REPLICA_ORDER); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java index 757af43ef67..aa890778804 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicyRequestTrackerTest.java @@ -28,8 +28,11 @@ import java.util.UUID; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.Silent.class) public class DefaultLoadBalancingPolicyRequestTrackerTest extends LoadBalancingPolicyTestBase { @Mock Request request; diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java index c9149efa69f..b301433ed64 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/loadbalancing/LoadBalancingPolicyTestBase.java @@ -35,14 +35,11 @@ import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import org.junit.After; import org.junit.Before; -import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; import org.slf4j.LoggerFactory; -@RunWith(MockitoJUnitRunner.class) public abstract class LoadBalancingPolicyTestBase { @Mock protected DefaultNode node1; @@ -81,6 +78,9 @@ public void setup() { DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_ALLOW_FOR_LOCAL_CONSISTENCY_LEVELS)) .thenReturn(false); when(defaultProfile.getString(DefaultDriverOption.REQUEST_CONSISTENCY)).thenReturn("ONE"); + when(defaultProfile.getString( + DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD)) + .thenReturn("REGULAR"); when(context.getMetadataManager()).thenReturn(metadataManager); diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java index c482afe7a47..e175acc267b 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java @@ -28,6 +28,7 @@ import com.datastax.oss.driver.api.core.DefaultProtocolVersion; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.context.DriverContext; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BoundStatement; @@ -294,6 +295,6 @@ private PreparedStatement mockPreparedStatement(String query, Map sessionRule = SessionRule.builder(ccmRule).build(); + private final SessionRule SESSION_RULE = SessionRule.builder(CCM_RULE).build(); - @Rule public TestRule chain = RuleChain.outerRule(ccmRule).around(sessionRule); + @Rule public TestRule chain = RuleChain.outerRule(CCM_RULE).around(SESSION_RULE); @Rule public TestName name = new TestName(); @@ -89,11 +90,11 @@ public void createTable() { SchemaChangeSynchronizer.withLock( () -> { for (String schemaStatement : schemaStatements) { - sessionRule + SESSION_RULE .session() .execute( SimpleStatement.newInstance(schemaStatement) - .setExecutionProfile(sessionRule.slowProfile())); + .setExecutionProfile(SESSION_RULE.slowProfile())); } }); } @@ -103,7 +104,7 @@ public void should_issue_log_warn_if_batched_statement_have_consistency_level_se SimpleStatement simpleStatement = SimpleStatement.builder("INSERT INTO test (k0, k1, v) values ('123123', ?, ?)").build(); - try (CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace())) { + try (CqlSession session = SessionUtils.newSession(CCM_RULE, SESSION_RULE.keyspace())) { PreparedStatement prep = session.prepare(simpleStatement); BatchStatementBuilder batch = BatchStatement.builder(DefaultBatchType.UNLOGGED); batch.addStatement(prep.bind(1, 2).setConsistencyLevel(ConsistencyLevel.QUORUM)); @@ -139,7 +140,7 @@ public void should_execute_batch_of_simple_statements_with_variables() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -154,14 +155,14 @@ public void should_execute_batch_of_bound_statements_with_variables() { String.format( "INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName())) .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { builder.addStatement(preparedStatement.bind(i, i + 1)); } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -178,14 +179,14 @@ public void should_execute_batch_of_bound_statements_with_unset_values() { String.format( "INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName())) .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { builder.addStatement(preparedStatement.bind(i, i + 1)); } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); @@ -196,17 +197,17 @@ public void should_execute_batch_of_bound_statements_with_unset_values() { if (i % 20 == 0) { boundStatement = boundStatement.unset(1); } - builder.addStatement(boundStatement); + builder2.addStatement(boundStatement); } - sessionRule.session().execute(builder2.build()); + SESSION_RULE.session().execute(builder2.build()); Statement select = SimpleStatement.builder("SELECT * from test where k0 = ?") .addPositionalValue(name.getMethodName()) .build(); - ResultSet result = sessionRule.session().execute(select); + ResultSet result = SESSION_RULE.session().execute(select); List rows = result.all(); assertThat(rows).hasSize(100); @@ -230,7 +231,7 @@ public void should_execute_batch_of_bound_statements_with_named_variables() { // variable values. BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.UNLOGGED); PreparedStatement preparedStatement = - sessionRule.session().prepare("INSERT INTO test (k0, k1, v) values (:k0, :k1, :v)"); + SESSION_RULE.session().prepare("INSERT INTO test (k0, k1, v) values (:k0, :k1, :v)"); for (int i = 0; i < batchCount; i++) { builder.addStatement( @@ -243,7 +244,7 @@ public void should_execute_batch_of_bound_statements_with_named_variables() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -257,7 +258,7 @@ public void should_execute_batch_of_bound_and_simple_statements_with_variables() String.format( "INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName())) .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { if (i % 2 == 1) { @@ -274,7 +275,7 @@ public void should_execute_batch_of_bound_and_simple_statements_with_variables() } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); verifyBatchInsert(); } @@ -284,25 +285,53 @@ public void should_execute_cas_batch() { // Build a batch with CAS operations on the same partition. BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.UNLOGGED); SimpleStatement insert = - SimpleStatement.builder( - String.format( - "INSERT INTO test (k0, k1, v) values ('%s', ? , ?) IF NOT EXISTS", - name.getMethodName())) + SimpleStatement.builder("INSERT INTO test (k0, k1, v) values (?, ?, ?) IF NOT EXISTS") .build(); - PreparedStatement preparedStatement = sessionRule.session().prepare(insert); + PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert); for (int i = 0; i < batchCount; i++) { - builder.addStatement(preparedStatement.bind(i, i + 1)); + builder.addStatement(preparedStatement.bind(name.getMethodName(), i, i + 1)); + } + + // Ensure LWT routing has a concrete routing key to compute replicas. + BoundStatement routingKeyStmt = preparedStatement.bind(name.getMethodName(), 0, 1); + builder.setRoutingKey(routingKeyStmt.getRoutingKey()); + builder.setSerialConsistencyLevel(ConsistencyLevel.SERIAL); + // Enforce LWT routing only for Cassandra where prepare metadata lacks LWT flags. + if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + builder.setIsLWT(true); } BatchStatement batchStatement = builder.build(); - ResultSet result = sessionRule.session().execute(batchStatement); + // Validate serial consistency and LWT routing on the batch itself. + assertThat(batchStatement.getSerialConsistencyLevel()).isEqualTo(ConsistencyLevel.SERIAL); + assertThat(batchStatement.isLWT()).isEqualTo(true); + assertThat(batchStatement.getRoutingKey()).isNotNull(); + + ResultSet result = SESSION_RULE.session().execute(batchStatement); + // Validate that executed request preserved serial consistency level. + assertThat(result.getExecutionInfo().getRequest()).isInstanceOf(Statement.class); + assertThat(((Statement) result.getExecutionInfo().getRequest()).getSerialConsistencyLevel()) + .isEqualTo(ConsistencyLevel.SERIAL); assertThat(result.wasApplied()).isTrue(); verifyBatchInsert(); - // re execute same batch and ensure wasn't applied. - result = sessionRule.session().execute(batchStatement); + // Rebuild an equivalent batch and ensure it isn't applied. + BatchStatementBuilder rerunBuilder = BatchStatement.builder(DefaultBatchType.UNLOGGED); + rerunBuilder.setSerialConsistencyLevel(ConsistencyLevel.SERIAL); + for (int i = 0; i < batchCount; i++) { + rerunBuilder.addStatement(preparedStatement.bind(name.getMethodName(), i, i + 1)); + } + // Use the same routing key to target the same partition for LWT. + rerunBuilder.setRoutingKey(routingKeyStmt.getRoutingKey()); + // Enforce LWT routing only for Cassandra where prepare metadata lacks LWT flags. + if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + rerunBuilder.setIsLWT(true); + } + BatchStatement rerunBatch = rerunBuilder.build(); + assertThat(rerunBatch.isLWT()).isEqualTo(true); + result = SESSION_RULE.session().execute(rerunBatch); assertThat(result.wasApplied()).isFalse(); } @@ -322,11 +351,11 @@ public void should_execute_counter_batch() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); for (int i = 1; i <= 3; i++) { ResultSet result = - sessionRule + SESSION_RULE .session() .execute( String.format( @@ -356,7 +385,7 @@ public void should_fail_logged_batch_with_counter_increment() { } BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); } @Test(expected = InvalidQueryException.class) @@ -383,7 +412,7 @@ public void should_fail_counter_batch_with_non_counter_increment() { builder.addStatement(simpleInsert); BatchStatement batchStatement = builder.build(); - sessionRule.session().execute(batchStatement); + SESSION_RULE.session().execute(batchStatement); } @Test @@ -394,13 +423,13 @@ public void should_not_allow_unset_value_when_protocol_less_than_v4() { SessionUtils.configLoaderBuilder() .withString(DefaultDriverOption.PROTOCOL_VERSION, "V3") .build(); - try (CqlSession v3Session = SessionUtils.newSession(ccmRule, loader)) { + try (CqlSession v3Session = SessionUtils.newSession(CCM_RULE, loader)) { // Intentionally use fully qualified table here to avoid warnings as these are not supported // by v3 protocol version, see JAVA-3068 PreparedStatement prepared = v3Session.prepare( String.format( - "INSERT INTO %s.test (k0, k1, v) values (?, ?, ?)", sessionRule.keyspace())); + "INSERT INTO %s.test (k0, k1, v) values (?, ?, ?)", SESSION_RULE.keyspace())); BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.LOGGED); builder.addStatements( @@ -427,7 +456,7 @@ private void verifyBatchInsert() { .addPositionalValue(name.getMethodName()) .build(); - ResultSet result = sessionRule.session().execute(select); + ResultSet result = SESSION_RULE.session().execute(select); List rows = result.all(); assertThat(rows).hasSize(100); diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java index 9e2d034a19f..7f24eb19978 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java @@ -43,6 +43,7 @@ import com.datastax.oss.driver.api.testinfra.session.SessionUtils; import java.nio.ByteBuffer; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -76,24 +77,26 @@ public static void setup() { } @Test - public void should_use_only_one_node_when_lwt_detected() { + public void should_use_replicas_when_lwt_detected() { assumeTrue( CcmBridge.isDistributionOf(BackendType.SCYLLA)); // Functionality only available in Scylla CqlSession session = SESSION_RULE.session(); int pk = 1234; ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); - Node owner = - tokenMap.getReplicasList(session.getKeyspace().get(), routingKey).iterator().next(); + List replicas = tokenMap.getReplicasList(session.getKeyspace().get(), routingKey); PreparedStatement statement = SESSION_RULE .session() .prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); assertThat(statement.isLWT()).isTrue(); - for (int i = 0; i < 30; i++) { + Set coordinators = new HashSet<>(); + for (int i = 0; i < 100; i++) { ResultSet result = session.execute(statement.bind(pk, i, 123)); - assertThat(result.getExecutionInfo().getCoordinator()).isEqualTo(owner); + coordinators.add(result.getExecutionInfo().getCoordinator()); } + assertThat(coordinators).isSubsetOf(replicas); + assertThat(coordinators.size()).isGreaterThan(0).isLessThanOrEqualTo(replicas.size()); } @Test @@ -116,22 +119,22 @@ public void should_not_use_only_one_node_when_non_lwt() { } @Test - public void should_use_only_one_node_when_lwt_batch_detected() { + public void should_use_replicas_when_lwt_batch_detected() { assumeTrue( CcmBridge.isDistributionOf(BackendType.SCYLLA)); // Functionality only available in Scylla CqlSession session = SESSION_RULE.session(); int pk = 1234; ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); - Node owner = - tokenMap.getReplicasList(session.getKeyspace().get(), routingKey).iterator().next(); + List replicas = tokenMap.getReplicasList(session.getKeyspace().get(), routingKey); PreparedStatement statement = SESSION_RULE .session() .prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); assertThat(statement.isLWT()).isTrue(); - for (int i = 0; i < 30; i++) { + Set coordinatorsLwt = new HashSet<>(); + for (int i = 0; i < 100; i++) { BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); SimpleStatement simpleStatement = SimpleStatement.newInstance( @@ -142,8 +145,10 @@ public void should_use_only_one_node_when_lwt_batch_detected() { batch = batch.add(statement.bind(pk, i, 123)); assertThat(batch.isLWT()).isTrue(); ResultSet result = session.execute(batch); - assertThat(result.getExecutionInfo().getCoordinator()).isEqualTo(owner); + coordinatorsLwt.add(result.getExecutionInfo().getCoordinator()); } + assertThat(coordinatorsLwt).isSubsetOf(replicas); + assertThat(coordinatorsLwt.size()).isGreaterThan(0).isLessThanOrEqualTo(replicas.size()); // Check if multiple coordinators are used when forcibly set to non-LWT Set coordinators = new HashSet<>(); diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java new file mode 100644 index 00000000000..7299a36eb19 --- /dev/null +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java @@ -0,0 +1,209 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright (C) 2026 ScyllaDB + * + * Modified by ScyllaDB + */ +package com.datastax.oss.driver.core.loadbalancing; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.Version; +import com.datastax.oss.driver.api.core.config.DefaultDriverOption; +import com.datastax.oss.driver.api.core.cql.BatchStatement; +import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder; +import com.datastax.oss.driver.api.core.cql.BatchType; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.metadata.Node; +import com.datastax.oss.driver.api.core.metadata.TokenMap; +import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; +import com.datastax.oss.driver.api.testinfra.ccm.CcmBridge; +import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule; +import com.datastax.oss.driver.api.testinfra.requirement.BackendType; +import com.datastax.oss.driver.api.testinfra.session.SessionRule; +import com.datastax.oss.driver.api.testinfra.session.SessionUtils; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; + +public class LWTLoadBalancingMultiDcIT { + private static final String LOCAL_DC = "dc1"; + private static final String KEYSPACE = "test"; + + private static final CustomCcmRule CCM_RULE = + CustomCcmRule.builder().withNodes(2, 1).build(); // 2 nodes in DC1, 1 node in DC2 + + private static final SessionRule SESSION_RULE = + SessionRule.builder(CCM_RULE) + .withKeyspace(false) + .withConfigLoader( + SessionUtils.configLoaderBuilder() + .withString(DefaultDriverOption.LOAD_BALANCING_LOCAL_DATACENTER, LOCAL_DC) + .withDuration(DefaultDriverOption.REQUEST_TIMEOUT, Duration.ofSeconds(30)) + .build()) + .build(); + + @ClassRule + public static final TestRule CHAIN = RuleChain.outerRule(CCM_RULE).around(SESSION_RULE); + + public static final int FIRST_TEST_PARTITION_KEY = 4242; + public static final int SECOND_TEST_PARTITION_KEY = 4343; + public static final int NUM_TEST_ITERATIONS = 30; + + @BeforeClass + public static void setup() { + CqlSession session = SESSION_RULE.session(); + + // Create multi-DC keyspace and table similarly to DefaultLoadBalancingPolicyIT. + if (CcmBridge.isDistributionOf(BackendType.SCYLLA) + && ((CcmBridge.SCYLLA_ENTERPRISE + && CcmBridge.getDistributionVersion().compareTo(Version.parse("2023.1.0")) >= 0) + || (!CcmBridge.SCYLLA_ENTERPRISE + && CcmBridge.getDistributionVersion().compareTo(Version.parse("6.1.0")) >= 0))) { + session.execute( + "CREATE KEYSPACE test " + + "WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': 2, 'dc2': 1} " + + "AND tablets = { 'enabled': false }"); + } else { + session.execute( + "CREATE KEYSPACE test " + + "WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': 2, 'dc2': 1}"); + } + + session.execute("CREATE TABLE test.foo (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); + + // Wait for schema readiness + await() + .pollInterval(200, TimeUnit.MILLISECONDS) + .atMost(60, TimeUnit.SECONDS) + .untilAsserted( + () -> { + assertThat(session.getMetadata().getKeyspace(KEYSPACE)).isPresent(); + TokenMap tm = session.getMetadata().getTokenMap().get(); + ByteBuffer routingKey = + TypeCodecs.INT.encodePrimitive(FIRST_TEST_PARTITION_KEY, ProtocolVersion.DEFAULT); + Set replicas = + new HashSet<>(tm.getReplicasList(CqlIdentifier.fromCql(KEYSPACE), routingKey)); + assertThat(replicas).hasSize(3); // RF 2 in dc1, 1 in dc2 + assertThat(replicas.stream().filter(n -> LOCAL_DC.equals(n.getDatacenter()))) + .hasSizeGreaterThanOrEqualTo(1); + }); + } + + @Test + public void should_route_lwt_to_local_dc_replicas() { + int pk = FIRST_TEST_PARTITION_KEY; + CqlIdentifier keyspace = CqlIdentifier.fromCql(KEYSPACE); + ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); + + TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); + Set localReplicas = new HashSet<>(); + Set allReplicas = new HashSet<>(tokenMap.getReplicasList(keyspace, routingKey)); + for (Node replica : allReplicas) { + if (LOCAL_DC.equals(replica.getDatacenter())) { + localReplicas.add(replica); + } + } + assertThat(localReplicas).isNotEmpty(); + + PreparedStatement lwt = + SESSION_RULE + .session() + .prepare("INSERT INTO test.foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); + // Cassandra does not expose LWT flag via prepare metadata; driver may not detect LWT. + if (!CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + assertThat(lwt.isLWT()).isTrue(); + } + + Set coordinators = new HashSet<>(); + Set coordinatorDcs = new HashSet<>(); + for (int i = 0; i < NUM_TEST_ITERATIONS; i++) { + ResultSet result = SESSION_RULE.session().execute(lwt.bind(pk, i, 7)); + Node coord = result.getExecutionInfo().getCoordinator(); + coordinators.add(coord); + coordinatorDcs.add(coord.getDatacenter()); + } + + assertThat(coordinators).isSubsetOf(allReplicas); + assertThat(coordinators).isSubsetOf(localReplicas); + assertThat(coordinatorDcs).containsOnly(LOCAL_DC); + } + + @Test + public void should_route_lwt_batch_to_local_dc_replicas() { + int pk = SECOND_TEST_PARTITION_KEY; + CqlIdentifier keyspace = CqlIdentifier.fromCql(KEYSPACE); + ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); + + TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); + Set localReplicas = new HashSet<>(); + Set allReplicas = new HashSet<>(tokenMap.getReplicasList(keyspace, routingKey)); + for (Node replica : allReplicas) { + if (LOCAL_DC.equals(replica.getDatacenter())) { + localReplicas.add(replica); + } + } + assertThat(localReplicas).isNotEmpty(); + + PreparedStatement lwt = + SESSION_RULE + .session() + .prepare("INSERT INTO test.foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); + PreparedStatement nonLwtPrepared = + SESSION_RULE.session().prepare("INSERT INTO test.foo (pk, ck, v) VALUES (?, ?, ?)"); + + // Run a bunch of times to exercise load balancing. + Set coordinators = new HashSet<>(); + Set coordinatorDcs = new HashSet<>(); + for (int i = 0; i < NUM_TEST_ITERATIONS; i++) { + BatchStatementBuilder builder = + BatchStatement.builder(BatchType.UNLOGGED) + .setRoutingKeyspace(keyspace) + .setRoutingKey(routingKey) + .addStatement(nonLwtPrepared.bind(pk, 0, 101)) + .addStatement(lwt.bind(pk, i, 202)); + // Ensure LWT routing type on Cassandra where detection may be absent + if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { + builder = builder.setIsLWT(true); + } + BatchStatement batch = builder.build(); + assertThat(batch.isLWT()).isTrue(); + + ResultSet result = SESSION_RULE.session().execute(batch); + Node coord = result.getExecutionInfo().getCoordinator(); + coordinators.add(coord); + coordinatorDcs.add(coord.getDatacenter()); + } + + assertThat(coordinators).isSubsetOf(allReplicas); + assertThat(coordinators).isSubsetOf(localReplicas); + assertThat(coordinatorDcs).containsOnly(LOCAL_DC); + } +} diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java index e468e0a10d7..dc7590da2ec 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/metadata/NodeStateIT.java @@ -57,6 +57,7 @@ import com.datastax.oss.simulacron.server.BoundNode; import com.datastax.oss.simulacron.server.RejectScope; import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -703,7 +704,7 @@ public void stopIgnoring(Node node) { @NonNull @Override - public Queue newQueryPlan(@NonNull Request request, @NonNull Session session) { + public Queue newQueryPlan(@Nullable Request request, @Nullable Session session) { Object[] snapshot = liveNodes.toArray(); Queue queryPlan = new ConcurrentLinkedQueue<>(); int start = offset.getAndIncrement(); // Note: offset overflow won't be an issue in tests From 658b44c162bd0e7fd9b1782b2a836f7018757f37 Mon Sep 17 00:00:00 2001 From: Mikita Hradovich Date: Fri, 23 Jan 2026 22:24:34 +0100 Subject: [PATCH 4/4] refactor: Streamline LWT handling in batch statements --- .../driver/api/core/RequestRoutingType.java | 3 + .../driver/api/core/cql/BatchStatement.java | 9 - .../api/core/cql/BatchStatementBuilder.java | 19 +- .../api/core/cql/BoundStatementBuilder.java | 5 +- .../api/core/cql/PreparedStatement.java | 5 + .../oss/driver/api/core/cql/Statement.java | 6 +- .../driver/api/core/cql/StatementBuilder.java | 8 +- .../oss/driver/api/core/session/Request.java | 2 +- .../core/cql/DefaultBatchStatement.java | 106 +++++----- .../core/cql/DefaultBoundStatement.java | 49 ++--- .../core/cql/DefaultPreparedStatement.java | 14 +- .../core/cql/DefaultSimpleStatement.java | 10 +- .../DefaultLoadBalancingPolicy.java | 15 +- .../core/cql/DefaultBatchStatementTest.java | 191 +++++++++++++++++- .../oss/driver/core/cql/BatchStatementIT.java | 5 +- .../loadbalancing/LWTLoadBalancingIT.java | 3 +- .../LWTLoadBalancingMultiDcIT.java | 3 +- 17 files changed, 306 insertions(+), 147 deletions(-) diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java index d8f6d6b9d68..43bffe99589 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/RequestRoutingType.java @@ -1,6 +1,9 @@ package com.datastax.oss.driver.api.core; +/** The type of routing for a given request. */ public enum RequestRoutingType { + /** A regular (non-LWT) request. */ REGULAR, + /** A lightweight transaction (LWT) request. */ LWT } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java index e831ed62369..63afd227425 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java @@ -280,13 +280,4 @@ default int computeSizeInBytes(@NonNull DriverContext context) { return size; } - - /** - * Overrides LWT state to a specific value. If unset or set to {@code null} the {@link - * Statement#isLWT()} method will infer result from the statments in the batch. - * - * @param newIsLWT new Boolean to set - * @return new BatchStatement with updated isLWT field. - */ - BatchStatement setIsLWT(Boolean newIsLWT); } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java index abf3ef0892e..8e34c916ea1 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java @@ -18,7 +18,6 @@ package com.datastax.oss.driver.api.core.cql; import com.datastax.oss.driver.api.core.CqlIdentifier; -import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.internal.core.cql.DefaultBatchStatement; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; @@ -40,7 +39,6 @@ public class BatchStatementBuilder extends StatementBuilder> statementsBuilder; private int statementsCount; - @Nullable private Boolean isLWT = null; public BatchStatementBuilder(@NonNull BatchType batchType) { this.batchType = batchType; @@ -77,19 +75,6 @@ public BatchStatementBuilder setKeyspace(@NonNull String keyspaceName) { return setKeyspace(CqlIdentifier.fromCql(keyspaceName)); } - /** - * Forces driver to see this batch as LWT or non-LWT. Note that if never explicitly set or set to - * {@code null}, the resulting {@code DefaultBatchStatement} will decide its LWT state based on - * contained statements. - * - * @return this builder; never {@code null}. - */ - @NonNull - public BatchStatementBuilder setIsLWT(Boolean newIsLWT) { - this.isLWT = newIsLWT; - return this; - } - /** * Adds a new statement to the batch. * @@ -153,8 +138,6 @@ public BatchStatementBuilder clearStatements() { @NonNull public BatchStatement build() { List> statements = statementsBuilder.build(); - RequestRoutingType routingType = - isLWT != null ? (isLWT ? RequestRoutingType.LWT : RequestRoutingType.REGULAR) : null; return new DefaultBatchStatement( batchType, statements, @@ -175,7 +158,7 @@ public BatchStatement build() { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } public int getStatementsCount() { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java index fbbcccee018..58a3a2319a2 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BoundStatementBuilder.java @@ -20,6 +20,7 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.metadata.token.Token; import com.datastax.oss.driver.api.core.type.DataType; @@ -67,7 +68,8 @@ public BoundStatementBuilder( @Nullable ConsistencyLevel serialConsistencyLevel, @Nullable Duration timeout, @NonNull CodecRegistry codecRegistry, - @NonNull ProtocolVersion protocolVersion) { + @NonNull ProtocolVersion protocolVersion, + @Nullable RequestRoutingType requestRoutingType) { this.preparedStatement = preparedStatement; this.variableDefinitions = variableDefinitions; this.values = values; @@ -89,6 +91,7 @@ public BoundStatementBuilder( this.timeout = timeout; this.codecRegistry = codecRegistry; this.protocolVersion = protocolVersion; + this.requestRoutingType = requestRoutingType; } public BoundStatementBuilder(@NonNull BoundStatement template) { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java index 982db8b3b41..7ad77463aed 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java @@ -25,6 +25,7 @@ import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.DefaultProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.metadata.token.Partitioner; import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; @@ -133,6 +134,10 @@ public interface PreparedStatement { */ boolean isLWT(); + /** Returns the request routing type for this prepared statement. */ + @Nullable + RequestRoutingType getRequestRoutingType(); + /** * Updates {@link #getResultMetadataId()} and {@link #getResultSetDefinitions()} atomically. * diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java index 68edc3a71a6..e88831e7925 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java @@ -536,7 +536,7 @@ default SelfT setNowInSeconds(int nowInSeconds) { @NonNull @CheckReturnValue @SuppressWarnings("unchecked") - default SelfT setRequestRoutingType(RequestRoutingType requestRoutingType) { + default SelfT setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { return (SelfT) this; } @@ -555,7 +555,9 @@ default SelfT setRequestRoutingType(RequestRoutingType requestRoutingType) { * * @see Docs about LWT */ - boolean isLWT(); + default boolean isLWT() { + return getRequestRoutingType() == RequestRoutingType.LWT; // treating null as non-LWT + } /** * Calculates the approximate size in bytes that the statement will have when encoded. diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java index ecfb5f57023..9894dd9c813 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/StatementBuilder.java @@ -62,7 +62,7 @@ public abstract class StatementBuilder< @Nullable protected Duration timeout; @Nullable protected Node node; protected int nowInSeconds = Statement.NO_NOW_IN_SECONDS; - @NonNull protected RequestRoutingType requestRoutingType = RequestRoutingType.REGULAR; + @Nullable protected RequestRoutingType requestRoutingType; protected StatementBuilder() { // nothing to do @@ -285,9 +285,9 @@ public SelfT setNowInSeconds(int nowInSeconds) { return self; } - /** @see Statement#setRequestRoutingType(RequestRoutingType) */ - public SelfT setRequestRoutingType(@NonNull RequestRoutingType routingType) { - this.requestRoutingType = routingType; + @NonNull + public SelfT setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { + this.requestRoutingType = requestRoutingType; return self; } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java index e92e3cc6814..c3035f2bf12 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java @@ -210,6 +210,6 @@ default Partitioner getPartitioner() { * * @return The routing type configured on this request */ - @NonNull + @Nullable RequestRoutingType getRequestRoutingType(); } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index 582c326743b..cde8d91e4c9 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -43,6 +43,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import net.jcip.annotations.Immutable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -70,7 +71,8 @@ public class DefaultBatchStatement implements BatchStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; - @NonNull private final RequestRoutingType routingType; + @Nullable private final RequestRoutingType requestRoutingType; + private RequestRoutingType cachedStatementsRequestRoutingType; public DefaultBatchStatement( BatchType batchType, @@ -92,7 +94,7 @@ public DefaultBatchStatement( Duration timeout, Node node, int nowInSeconds, - @NonNull RequestRoutingType routingType) { + @Nullable RequestRoutingType requestRoutingType) { for (BatchableStatement statement : statements) { if (statement != null && (statement.getConsistencyLevel() != null @@ -124,7 +126,7 @@ public DefaultBatchStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; - this.routingType = routingType; + this.requestRoutingType = requestRoutingType; } @NonNull @@ -156,7 +158,7 @@ public BatchStatement setBatchType(@NonNull BatchType newBatchType) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -182,7 +184,7 @@ public BatchStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -212,7 +214,7 @@ public BatchStatement add(@NonNull BatchableStatement statement) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } } @@ -246,7 +248,7 @@ public BatchStatement addAll(@NonNull Iterable> timeout, node, nowInSeconds, - routingType); + requestRoutingType); } } @@ -278,7 +280,7 @@ public BatchStatement clear() { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -315,7 +317,7 @@ public BatchStatement setPagingState(ByteBuffer newPagingState) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -346,7 +348,7 @@ public BatchStatement setPageSize(int newPageSize) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Nullable @@ -378,7 +380,7 @@ public BatchStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Nullable @@ -411,7 +413,7 @@ public BatchStatement setSerialConsistencyLevel( timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -442,7 +444,7 @@ public BatchStatement setExecutionProfileName(@Nullable String newConfigProfileN timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -473,7 +475,7 @@ public DefaultBatchStatement setExecutionProfile(@Nullable DriverExecutionProfil timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -539,7 +541,7 @@ public BatchStatement setRoutingKeyspace(CqlIdentifier newRoutingKeyspace) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -565,7 +567,7 @@ public BatchStatement setNode(@Nullable Node newNode) { timeout, newNode, nowInSeconds, - routingType); + requestRoutingType); } @Nullable @@ -612,7 +614,7 @@ public BatchStatement setRoutingKey(ByteBuffer newRoutingKey) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -653,7 +655,7 @@ public BatchStatement setRoutingToken(Token newRoutingToken) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -685,7 +687,7 @@ public DefaultBatchStatement setCustomPayload(@NonNull Map n timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -722,7 +724,7 @@ public DefaultBatchStatement setIdempotent(Boolean newIdempotence) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -753,7 +755,7 @@ public BatchStatement setTracing(boolean newTracing) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -784,7 +786,7 @@ public BatchStatement setQueryTimestamp(long newTimestamp) { timeout, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -810,7 +812,7 @@ public BatchStatement setTimeout(@Nullable Duration newTimeout) { newTimeout, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -841,13 +843,31 @@ public BatchStatement setNowInSeconds(int newNowInSeconds) { timeout, node, newNowInSeconds, - routingType); + requestRoutingType); } - @NonNull + /** + * Returns the request routing type for this batch statement based on {@link + * DefaultBatchStatement#isLWT()} implementation while maintaining non-null contract. + * + * @return the request routing type, never null + */ + @Nullable @Override public RequestRoutingType getRequestRoutingType() { - return routingType; + if (Objects.nonNull(requestRoutingType)) { + return requestRoutingType; + } else if (Objects.isNull( + cachedStatementsRequestRoutingType)) { // Immutability of the statement list and statements + // allows us to cache the result + cachedStatementsRequestRoutingType = + statements.stream() + .map(Statement::getRequestRoutingType) + .filter((rt) -> Objects.nonNull(rt) && rt == RequestRoutingType.LWT) + .findFirst() + .orElse(RequestRoutingType.REGULAR); + } + return cachedStatementsRequestRoutingType; } @NonNull @@ -875,38 +895,4 @@ public BatchStatement setRequestRoutingType(RequestRoutingType requestRoutingTyp nowInSeconds, requestRoutingType); } - - @NonNull - @Override - public BatchStatement setIsLWT(Boolean newIsLWT) { - RequestRoutingType routingType = - newIsLWT != null ? (newIsLWT ? RequestRoutingType.LWT : RequestRoutingType.REGULAR) : null; - return new DefaultBatchStatement( - batchType, - statements, - executionProfileName, - executionProfile, - keyspace, - routingKeyspace, - routingKey, - routingToken, - customPayload, - idempotent, - tracing, - timestamp, - pagingState, - pageSize, - consistencyLevel, - serialConsistencyLevel, - timeout, - node, - nowInSeconds, - routingType); - } - - @Override - public boolean isLWT() { - if (routingType != null) return routingType == RequestRoutingType.LWT; - return statements.stream().anyMatch(Statement::isLWT); - } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java index c2024dcf8b0..2c3ad902f39 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java @@ -44,6 +44,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; import net.jcip.annotations.Immutable; @Immutable @@ -70,7 +71,7 @@ public class DefaultBoundStatement implements BoundStatement { private final ProtocolVersion protocolVersion; private final Node node; private final int nowInSeconds; - @NonNull private final RequestRoutingType routingType; + @Nullable private final RequestRoutingType requestRoutingType; public DefaultBoundStatement( PreparedStatement preparedStatement, @@ -94,7 +95,7 @@ public DefaultBoundStatement( ProtocolVersion protocolVersion, Node node, int nowInSeconds, - @NonNull RequestRoutingType routingType) { + @Nullable RequestRoutingType requestRoutingType) { this.preparedStatement = preparedStatement; this.variableDefinitions = variableDefinitions; this.values = values; @@ -116,7 +117,7 @@ public DefaultBoundStatement( this.protocolVersion = protocolVersion; this.node = node; this.nowInSeconds = nowInSeconds; - this.routingType = routingType; + this.requestRoutingType = requestRoutingType; } @Override @@ -212,7 +213,7 @@ public BoundStatement setBytesUnsafe(int i, ByteBuffer v) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -257,7 +258,7 @@ public BoundStatement setExecutionProfileName(@Nullable String newConfigProfileN protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -290,7 +291,7 @@ public BoundStatement setExecutionProfile(@Nullable DriverExecutionProfile newPr protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -341,7 +342,7 @@ public BoundStatement setRoutingKeyspace(@Nullable CqlIdentifier newRoutingKeysp protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -369,7 +370,7 @@ public BoundStatement setNode(@Nullable Node newNode) { protocolVersion, newNode, nowInSeconds, - routingType); + requestRoutingType); } @Nullable @@ -430,7 +431,7 @@ public BoundStatement setRoutingKey(@Nullable ByteBuffer newRoutingKey) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -463,7 +464,7 @@ public BoundStatement setRoutingToken(@Nullable Token newRoutingToken) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @NonNull @@ -497,7 +498,7 @@ public BoundStatement setCustomPayload(@NonNull Map newCusto protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -530,7 +531,7 @@ public BoundStatement setIdempotent(@Nullable Boolean newIdempotence) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -563,7 +564,7 @@ public BoundStatement setTracing(boolean newTracing) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -596,7 +597,7 @@ public BoundStatement setQueryTimestamp(long newTimestamp) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Nullable @@ -630,7 +631,7 @@ public BoundStatement setTimeout(@Nullable Duration newTimeout) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -663,7 +664,7 @@ public BoundStatement setPagingState(@Nullable ByteBuffer newPagingState) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -696,7 +697,7 @@ public BoundStatement setPageSize(int newPageSize) { protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Nullable @@ -730,7 +731,7 @@ public BoundStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Nullable @@ -765,7 +766,7 @@ public BoundStatement setSerialConsistencyLevel( protocolVersion, node, nowInSeconds, - routingType); + requestRoutingType); } @Override @@ -798,18 +799,20 @@ public BoundStatement setNowInSeconds(int newNowInSeconds) { protocolVersion, node, newNowInSeconds, - routingType); + requestRoutingType); } - @NonNull + @Nullable @Override public RequestRoutingType getRequestRoutingType() { - return routingType; + return Objects.nonNull(requestRoutingType) + ? requestRoutingType + : preparedStatement.getRequestRoutingType(); } @NonNull @Override - public BoundStatement setRequestRoutingType(@NonNull RequestRoutingType requestRoutingType) { + public BoundStatement setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { return new DefaultBoundStatement( preparedStatement, variableDefinitions, diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java index 3994a5683ba..754a89ac228 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java @@ -48,6 +48,7 @@ import com.datastax.oss.driver.internal.core.session.RepreparePayload; import com.datastax.oss.driver.shaded.guava.common.base.Splitter; import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; import java.nio.ByteBuffer; import java.time.Duration; import java.util.List; @@ -83,7 +84,7 @@ public class DefaultPreparedStatement implements PreparedStatement { private final ConsistencyLevel serialConsistencyLevelForBoundStatements; private final Duration timeoutForBoundStatements; private final Partitioner partitioner; - @NonNull private final RequestRoutingType requestRoutingType; + @Nullable private final RequestRoutingType requestRoutingType; private volatile boolean skipMetadata; public DefaultPreparedStatement( @@ -111,7 +112,7 @@ public DefaultPreparedStatement( boolean areBoundStatementsTracing, CodecRegistry codecRegistry, ProtocolVersion protocolVersion, - @NonNull RequestRoutingType requestRoutingType) { + @Nullable RequestRoutingType requestRoutingType) { this.id = id; this.partitionKeyIndices = partitionKeyIndices; // It's important that we keep a reference to this object, so that it only gets evicted from @@ -192,6 +193,12 @@ public boolean isLWT() { return requestRoutingType == RequestRoutingType.LWT; } + @Nullable + @Override + public RequestRoutingType getRequestRoutingType() { + return requestRoutingType; + } + @Override public void setResultMetadata( @NonNull ByteBuffer newResultMetadataId, @NonNull ColumnDefinitions newResultSetDefinitions) { @@ -257,7 +264,8 @@ public BoundStatementBuilder boundStatementBuilder(@NonNull Object... values) { serialConsistencyLevelForBoundStatements, timeoutForBoundStatements, codecRegistry, - protocolVersion); + protocolVersion, + requestRoutingType); } public RepreparePayload getRepreparePayload() { diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java index f1a0495d6ed..0268689d86f 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java @@ -65,7 +65,7 @@ public class DefaultSimpleStatement implements SimpleStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; - @NonNull private final RequestRoutingType requestRoutingType; + @Nullable private final RequestRoutingType requestRoutingType; /** @see SimpleStatement#builder(String) */ public DefaultSimpleStatement( @@ -89,7 +89,7 @@ public DefaultSimpleStatement( Duration timeout, Node node, int nowInSeconds, - @NonNull RequestRoutingType requestRoutingType) { + @Nullable RequestRoutingType requestRoutingType) { if (!positionalValues.isEmpty() && !namedValues.isEmpty()) { throw new IllegalArgumentException("Can't have both positional and named values"); } @@ -773,7 +773,7 @@ public SimpleStatement setNowInSeconds(int newNowInSeconds) { requestRoutingType); } - @NonNull + @Nullable @Override public RequestRoutingType getRequestRoutingType() { return requestRoutingType; @@ -781,7 +781,7 @@ public RequestRoutingType getRequestRoutingType() { @NonNull @Override - public SimpleStatement setRequestRoutingType(@NonNull RequestRoutingType requestRoutingType) { + public SimpleStatement setRequestRoutingType(@Nullable RequestRoutingType requestRoutingType) { return new DefaultSimpleStatement( query, positionalValues, @@ -808,7 +808,7 @@ public SimpleStatement setRequestRoutingType(@NonNull RequestRoutingType request @Override public boolean isLWT() { - return false; + return requestRoutingType == RequestRoutingType.LWT; } public static Map wrapKeys(Map namedValues) { diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java index 66a1c13b3eb..f798ff033c2 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java @@ -20,6 +20,7 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverExecutionProfile; import com.datastax.oss.driver.api.core.context.DriverContext; @@ -156,12 +157,10 @@ public RequestRoutingMethod getDefaultLWTRequestRoutingMethod(@Nullable Request if (request == null) { return RequestRoutingMethod.REGULAR; } - switch (request.getRequestRoutingType()) { - case LWT: - return lwtRequestRoutingMethod; - case REGULAR: - default: - return RequestRoutingMethod.REGULAR; + if (request.getRequestRoutingType() == RequestRoutingType.LWT) { + return lwtRequestRoutingMethod; + } else { + return RequestRoutingMethod.REGULAR; } } @@ -171,6 +170,7 @@ public Queue newQueryPlan(@Nullable Request request, @Nullable Session ses switch (getDefaultLWTRequestRoutingMethod(request)) { case PRESERVE_REPLICA_ORDER: return newQueryPlanPreserveReplicas(request, session); + case REGULAR: default: return newQueryPlanRegular(request, session); } @@ -416,8 +416,7 @@ protected class NodeResponseRateSample { @VisibleForTesting protected final OptionalLong newest; private NodeResponseRateSample() { - long now = nanoTime(); - this.oldest = now; + this.oldest = nanoTime(); this.newest = OptionalLong.empty(); } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java index 2377968b4fc..3f38ddaf3cb 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java @@ -26,6 +26,7 @@ import ch.qos.logback.classic.Level; import ch.qos.logback.classic.spi.ILoggingEvent; import com.datastax.oss.driver.api.core.DefaultConsistencyLevel; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder; import com.datastax.oss.driver.api.core.cql.BatchType; @@ -109,43 +110,50 @@ public void should_infer_lwt_status() { SimpleStatement.builder("SELECT * FROM some_table WHERE a = ?").build(); BoundStatement lwtBoundStatement = mock(DefaultBoundStatement.class); when(lwtBoundStatement.isLWT()).thenReturn(true); + when(lwtBoundStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); // Without LWT statements added BatchStatementBuilder batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); batchStatementBuilder.addStatement(simpleStatement); - assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + BatchStatement batch = batchStatementBuilder.build(); + assertThat(batch.isLWT()).isFalse(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); // Check if implicitly set to true after adding LWT bound statement batchStatementBuilder.addStatement(lwtBoundStatement); assertThat(batchStatementBuilder.build().isLWT()).isTrue(); // Check if explicit set overrides implicit resolution - batchStatementBuilder.setIsLWT(false); - assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + batchStatementBuilder.setRequestRoutingType(RequestRoutingType.REGULAR); + batch = batchStatementBuilder.build(); + assertThat(batch.isLWT()).isFalse(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); batchStatementBuilder.addStatement(simpleStatement); - batchStatementBuilder.setIsLWT(true); - assertThat(batchStatementBuilder.build().isLWT()).isTrue(); + batchStatementBuilder.setRequestRoutingType(RequestRoutingType.LWT); + batch = batchStatementBuilder.build(); + assertThat(batch.isLWT()).isTrue(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); // Check if explicit set remains after clear assertThat(batchStatementBuilder.build().clear().isLWT()).isTrue(); // Similar checks without using builder - BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = BatchStatement.newInstance(BatchType.UNLOGGED); assertThat(batch.isLWT()).isFalse(); batch = batch.add(simpleStatement); assertThat(batch.isLWT()).isFalse(); batch = batch.add(lwtBoundStatement); assertThat(batch.isLWT()).isTrue(); - batch = batch.setIsLWT(false); + batch = batch.setRequestRoutingType(RequestRoutingType.REGULAR); assertThat(batch.isLWT()).isFalse(); batch = batch.add(lwtBoundStatement); assertThat(batch.isLWT()).isFalse(); - batch = batch.setIsLWT(true); + batch = batch.setRequestRoutingType(RequestRoutingType.LWT); assertThat(batch.isLWT()).isTrue(); batch = batch.clear(); assertThat(batch.isLWT()).isTrue(); - batch = batch.setIsLWT(null); + batch = batch.setRequestRoutingType(null); assertThat(batch.isLWT()).isFalse(); assertThat(BatchStatement.newInstance(BatchType.UNLOGGED).isLWT()).isFalse(); @@ -155,4 +163,169 @@ public void should_infer_lwt_status() { assertThat(BatchStatement.newInstance(BatchType.LOGGED, lwtBoundStatement).isLWT()).isTrue(); assertThat(BatchStatement.newInstance(BatchType.COUNTER, lwtBoundStatement).isLWT()).isTrue(); } + + @Test + public void should_handle_null_routing_type_in_empty_batch() { + // Empty batch should return REGULAR (not null) and isLWT should be false + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + // Same for other batch types + batch = BatchStatement.newInstance(BatchType.LOGGED); + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + batch = BatchStatement.newInstance(BatchType.COUNTER); + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + } + + @Test + public void should_handle_statements_with_null_routing_types() { + // Create statements that return null routing type + BoundStatement nullRoutingStatement1 = mock(DefaultBoundStatement.class); + when(nullRoutingStatement1.isLWT()).thenReturn(false); + when(nullRoutingStatement1.getRequestRoutingType()).thenReturn(null); + + BoundStatement nullRoutingStatement2 = mock(DefaultBoundStatement.class); + when(nullRoutingStatement2.isLWT()).thenReturn(false); + when(nullRoutingStatement2.getRequestRoutingType()).thenReturn(null); + + // Batch with only null routing type statements should return REGULAR + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(nullRoutingStatement1); + batch = batch.add(nullRoutingStatement2); + + assertThat(batch.getRequestRoutingType()).isNotNull(); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + } + + @Test + public void should_handle_mixed_null_and_non_null_routing_types() { + // Create statements with different routing types + BoundStatement nullRoutingStatement = mock(DefaultBoundStatement.class); + when(nullRoutingStatement.isLWT()).thenReturn(false); + when(nullRoutingStatement.getRequestRoutingType()).thenReturn(null); + + BoundStatement regularStatement = mock(DefaultBoundStatement.class); + when(regularStatement.isLWT()).thenReturn(false); + when(regularStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.REGULAR); + + BoundStatement lwtStatement = mock(DefaultBoundStatement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); + + // Test 1: null + regular -> REGULAR + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(nullRoutingStatement); + batch = batch.add(regularStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + // Test 2: null + LWT -> LWT (LWT should be detected) + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(nullRoutingStatement); + batch = batch.add(lwtStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + // Test 3: regular + null + LWT -> LWT (LWT should be detected regardless of order) + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(regularStatement); + batch = batch.add(nullRoutingStatement); + batch = batch.add(lwtStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + // Test 4: LWT + null + regular -> LWT (order shouldn't matter) + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(lwtStatement); + batch = batch.add(nullRoutingStatement); + batch = batch.add(regularStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + } + + @Test + public void should_handle_explicit_null_routing_type_override() { + BoundStatement lwtStatement = mock(DefaultBoundStatement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); + + BoundStatement regularStatement = mock(DefaultBoundStatement.class); + when(regularStatement.isLWT()).thenReturn(false); + when(regularStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.REGULAR); + + // Test 1: Batch with LWT statement, then set routing type to null + // Should fall back to inference and detect LWT + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(lwtStatement); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + batch = batch.setRequestRoutingType(null); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + assertThat(batch.isLWT()).isTrue(); + + // Test 2: Batch with regular statement, set routing type to null + // Should infer REGULAR + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.add(regularStatement); + batch = batch.setRequestRoutingType(null); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + + // Test 3: Empty batch with explicit null routing type + // Should return REGULAR + batch = BatchStatement.newInstance(BatchType.UNLOGGED); + batch = batch.setRequestRoutingType(null); + assertThat(batch.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + assertThat(batch.isLWT()).isFalse(); + } + + @Test + public void should_return_non_null_routing_type_consistently() { + // Verify that getRequestRoutingType never returns null + SimpleStatement simpleStatement = + SimpleStatement.builder("SELECT * FROM some_table WHERE a = ?").build(); + + BoundStatement lwtStatement = mock(DefaultBoundStatement.class); + when(lwtStatement.isLWT()).thenReturn(true); + when(lwtStatement.getRequestRoutingType()).thenReturn(RequestRoutingType.LWT); + + BoundStatement nullRoutingStatement = mock(DefaultBoundStatement.class); + when(nullRoutingStatement.isLWT()).thenReturn(false); + when(nullRoutingStatement.getRequestRoutingType()).thenReturn(null); + + // Test various batch configurations + BatchStatement batch1 = BatchStatement.newInstance(BatchType.UNLOGGED); + assertThat(batch1.getRequestRoutingType()).isNotNull(); + + BatchStatement batch2 = batch1.add(simpleStatement); + assertThat(batch2.getRequestRoutingType()).isNotNull(); + + BatchStatement batch3 = batch2.add(lwtStatement); + assertThat(batch3.getRequestRoutingType()).isNotNull(); + + BatchStatement batch4 = batch3.setRequestRoutingType(null); + assertThat(batch4.getRequestRoutingType()).isNotNull(); + + BatchStatement batch5 = + BatchStatement.newInstance(BatchType.UNLOGGED).add(nullRoutingStatement); + assertThat(batch5.getRequestRoutingType()).isNotNull(); + assertThat(batch5.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + + BatchStatement batch6 = batch5.setRequestRoutingType(RequestRoutingType.LWT); + assertThat(batch6.getRequestRoutingType()).isNotNull(); + assertThat(batch6.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.LWT); + + BatchStatement batch7 = batch6.setRequestRoutingType(null); + assertThat(batch7.getRequestRoutingType()).isNotNull(); + assertThat(batch7.getRequestRoutingType()).isEqualByComparingTo(RequestRoutingType.REGULAR); + } } diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/BatchStatementIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/BatchStatementIT.java index ffa2e8046f8..5ef33103598 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/BatchStatementIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/BatchStatementIT.java @@ -31,6 +31,7 @@ import ch.qos.logback.classic.spi.ILoggingEvent; import com.datastax.oss.driver.api.core.ConsistencyLevel; import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfigLoader; import com.datastax.oss.driver.api.core.cql.BatchStatement; @@ -299,7 +300,7 @@ public void should_execute_cas_batch() { builder.setSerialConsistencyLevel(ConsistencyLevel.SERIAL); // Enforce LWT routing only for Cassandra where prepare metadata lacks LWT flags. if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { - builder.setIsLWT(true); + builder.setRequestRoutingType(RequestRoutingType.LWT); } BatchStatement batchStatement = builder.build(); @@ -327,7 +328,7 @@ public void should_execute_cas_batch() { rerunBuilder.setRoutingKey(routingKeyStmt.getRoutingKey()); // Enforce LWT routing only for Cassandra where prepare metadata lacks LWT flags. if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { - rerunBuilder.setIsLWT(true); + rerunBuilder.setRequestRoutingType(RequestRoutingType.LWT); } BatchStatement rerunBatch = rerunBuilder.build(); assertThat(rerunBatch.isLWT()).isEqualTo(true); diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java index 7f24eb19978..0586b3236bb 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.PreparedStatement; @@ -161,7 +162,7 @@ public void should_use_replicas_when_lwt_batch_detected() { assertThat(simpleStatement.isLWT()).isFalse(); batch = batch.add(simpleStatement); batch = batch.add(statement.bind(pk, i, 123)); - batch = batch.setIsLWT(false); + batch = batch.setRequestRoutingType(RequestRoutingType.REGULAR); assertThat(batch.isLWT()).isFalse(); ResultSet result = session.execute(batch); coordinators.add(result.getExecutionInfo().getCoordinator()); diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java index 7299a36eb19..011c1f3ea0a 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingMultiDcIT.java @@ -27,6 +27,7 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.RequestRoutingType; import com.datastax.oss.driver.api.core.Version; import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.cql.BatchStatement; @@ -191,7 +192,7 @@ public void should_route_lwt_batch_to_local_dc_replicas() { .addStatement(lwt.bind(pk, i, 202)); // Ensure LWT routing type on Cassandra where detection may be absent if (CcmBridge.isDistributionOf(BackendType.CASSANDRA)) { - builder = builder.setIsLWT(true); + builder = builder.setRequestRoutingType(RequestRoutingType.LWT); } BatchStatement batch = builder.build(); assertThat(batch.isLWT()).isTrue();