From e7fb66b67fef4b6856204991e1652b279b81b8cc Mon Sep 17 00:00:00 2001 From: Jerry Peng Date: Fri, 29 May 2026 02:40:05 +0000 Subject: [PATCH] [SPARK-XXXXX][SHUFFLE] Add StreamingShuffleManager, MultiShuffleManager, and logging mixin Introduces the streaming shuffle manager layer: * StreamingShuffleManager - ShuffleManager implementation for streaming shuffle. getWriter/getReader are stubbed here and implemented in the follow-up push-path and pull-path PRs. * MultiShuffleManager - routes each shuffle to either the batch (SortShuffleManager) or streaming manager. * TaskContextAwareLogging - logging mixin that prefixes queryId / shuffleId / stageId / taskId. * SparkEnv - expose the (already-merged) StreamingShuffleOutputTracker to executors when the configured shuffle manager is the streaming or multi manager. * Streaming shuffle error conditions and the STREAMING_QUERY_ID log key. * StreamingShuffleManagerSuite and MultiShuffleManagerSuite covering the manager APIs, routing, and the SparkEnv tracker-initialization gating. 
Co-authored-by: Isaac --- .../org/apache/spark/internal/LogKeys.java | 1 + .../resources/error/error-conditions.json | 12 ++ .../scala/org/apache/spark/SparkEnv.scala | 49 ++++++ .../streaming/MultiShuffleManager.scala | 149 +++++++++++++++++ .../streaming/StreamingShuffleManager.scala | 150 ++++++++++++++++++ .../streaming/TaskContextAwareLogging.scala | 109 +++++++++++++ .../streaming/MultiShuffleManagerSuite.scala | 70 ++++++++ .../StreamingShuffleManagerSuite.scala | 119 ++++++++++++++ 8 files changed, 659 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala create mode 100644 core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index d8ce9d025af98..37064bf776312 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -794,6 +794,7 @@ public enum LogKeys implements LogKey { STREAMING_DATA_SOURCE_NAME, STREAMING_OFFSETS_END, STREAMING_OFFSETS_START, + STREAMING_QUERY_ID, STREAMING_QUERY_PROGRESS, STREAMING_SOURCE, STREAMING_TABLE, diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 71dbf1a1ebce3..f6ee16a8c4c2c 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -7130,6 +7130,18 @@ }, "sqlState" : "0A000" }, + 
"STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER" : { + "message" : [ + "Streaming shuffle <messageType> between writer <writerId> and reader <readerId> expected to have sequence number <expSeqNum>, but the actual sequence number is <actSeqNum>. Please verify that the messages are sent in order." + ], + "sqlState" : "XXKST" + }, + "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE" : { + "message" : [ + "Unexpected message type <messageType> encountered during streaming shuffle." + ], + "sqlState" : "XXKST" + }, "STREAMING_STATEFUL_OPERATOR_MISSING_STATE_DIRECTORY" : { "message" : [ "Cannot restart streaming query with stateful operators because the state directory is empty or missing.", diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4e56c88501ede..52532c4d4e00e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -46,6 +46,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.streaming.{MultiShuffleManager, StreamingShuffleManager} import org.apache.spark.storage._ import org.apache.spark.udf.worker.UDFWorkerSpecification import org.apache.spark.udf.worker.core.{UDFDispatcherFactory, UDFDispatcherManager, WorkerDispatcher} @@ -181,6 +182,7 @@ class SparkEnv ( pythonWorkers.values.foreach(_.stop()) udfDispatcherManager.foreach(_.close()) mapOutputTracker.stop() + _streamingShuffleOutputTracker.foreach(_.stop()) if (shuffleManager != null) { shuffleManager.stop() } @@ -299,6 +301,53 @@ class SparkEnv ( // Signal that the ShuffleManager has been initialized shuffleManagerInitLatch.countDown() } + initializeStreamingShuffleOutputTracker() + } + + // Holds the streaming shuffle output tracker, which is only present when the configured + // shuffle manager 
requires it (i.e., StreamingShuffleManager or MultiShuffleManager).
+  @volatile private var _streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
+    None
+
+  /** The streaming shuffle output tracker, if one has been initialized for this env. */
+  def streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
+    _streamingShuffleOutputTracker
+
+  /**
+   * Initialize the StreamingShuffleOutputTracker if the configured shuffle manager requires one
+   * and one does not already exist. This method is idempotent -- calling it multiple times is safe.
+   *
+   * This is separated from initializeShuffleManager() to allow the driver to register the
+   * tracker's RPC endpoint before the SHUFFLE_MANAGER config change propagates to executors,
+   * eliminating the race condition where executors try to look up the endpoint before the driver
+   * has registered it.
+   *
+   * On the driver this creates the master tracker and registers its RPC endpoint under
+   * StreamingShuffleOutputTracker.ENDPOINT_NAME; otherwise it creates a worker tracker that
+   * holds a reference to the driver's endpoint.
+   */
+  private def initializeStreamingShuffleOutputTracker(): Unit = {
+    if (_streamingShuffleOutputTracker.isEmpty) {
+      val shuffleManagerName = ShuffleManager.getShuffleManagerClassName(conf)
+      val requiresTracker = shuffleManagerName == classOf[StreamingShuffleManager].getName ||
+        shuffleManagerName == classOf[MultiShuffleManager].getName
+      if (requiresTracker) {
+        // Create the tracker and wire its endpoint in one role-specific branch, so the master
+        // instance is passed to its endpoint directly instead of being downcast afterwards.
+        val tracker: StreamingShuffleOutputTracker = if (SparkContext.isDriver(executorId)) {
+          val master = new StreamingShuffleOutputTrackerMaster(conf)
+          master.trackerEndpoint = rpcEnv.setupEndpoint(
+            StreamingShuffleOutputTracker.ENDPOINT_NAME,
+            new StreamingShuffleOutputTrackerMasterEndpoint(rpcEnv, master, conf))
+          master
+        } else {
+          val worker = new StreamingShuffleOutputTrackerWorker(conf)
+          worker.trackerEndpoint = RpcUtils.makeDriverRef(
+            StreamingShuffleOutputTracker.ENDPOINT_NAME, conf, rpcEnv)
+          worker
+        }
+        _streamingShuffleOutputTracker = Some(tracker)
+      }
+    }
   }
 
   private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala
b/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala new file mode 100644 index 0000000000000..3f6c74ff75009 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import java.util.Properties
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.{ShuffleBlockResolver, ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleReadMetricsReporter, ShuffleWriteMetricsReporter, ShuffleWriter}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.streaming.MultiShuffleManager.isStreamingShuffleEnabled
+
+/**
+ * A [[ShuffleHandle]] that pairs a streaming shuffle handle with the handle of another shuffle
+ * implementation. Its shuffle id is the shuffle id of the streaming handle.
+ */
+class MultiShuffleHandle(
+    val streamingShuffleHandle: ShuffleHandle,
+    val otherShuffleHandle: ShuffleHandle)
+  extends ShuffleHandle(streamingShuffleHandle.shuffleId)
+
+object MultiShuffleManager {
+  /** Local property that selects streaming shuffle for the current query when set to "true". */
+  val STREAMING_SHUFFLE_ENABLED_PROPERTY = "spark.shuffle.streaming.useForCurrentQuery"
+
+  /**
+   * Returns true only when [[STREAMING_SHUFFLE_ENABLED_PROPERTY]] is present in `properties`
+   * with the exact value "true"; a missing property or any other value yields false.
+   */
+  def isStreamingShuffleEnabled(properties: Properties): Boolean =
+    "true" == properties.getProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY)
+}
+
+/**
+ * This shuffle manager allows real-time queries that depend on streaming shuffle and normal
+ * queries that depend on sort shuffle to coexist in a cluster. Right now, the shuffle manager
+ * can only be configured at the cluster level, so consider using this shuffle manager if you
+ * want to run batch and real-time queries at the same time.
+ *
+ * Each shuffle id is routed to a [[StreamingShuffleManager]] or a [[SortShuffleManager]]
+ * according to [[MultiShuffleManager.isStreamingShuffleEnabled]], evaluated on the local
+ * properties visible the first time the shuffle id is seen. The choice is cached until the
+ * shuffle is unregistered.
+ */
+class MultiShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+  // To make sure the type of shuffle manager used for a shuffle is the same during its lifetime
+  private val shuffleIdToManager = new ConcurrentHashMap[Int, ShuffleManager]()
+  // The delegates are created on first use. They are only read or written while holding the
+  // shuffleIdToManager monitor, which every public method of this class acquires.
+  private var streamingShuffleManager: Option[StreamingShuffleManager] = None
+  private var sortShuffleManager: Option[SortShuffleManager] = None
+
+  /**
+   * Returns the manager pinned to `shuffleId`, choosing and caching one on first access.
+   * The local properties come from the active SparkContext if there is one, otherwise from the
+   * current TaskContext; an internal error is raised if neither is available.
+   */
+  private def shuffleManager(shuffleId: Int): ShuffleManager = {
+    shuffleIdToManager.computeIfAbsent(shuffleId, _ => {
+      val properties = SparkContext.getActive.map(_.getLocalProperties)
+        .orElse(Option(TaskContext.get()).map(_.getLocalProperties))
+        .getOrElse(throw SparkException.internalError(
+          "Cannot determine streaming shuffle routing: no active SparkContext or TaskContext"))
+      if (isStreamingShuffleEnabled(properties)) {
+        streamingShuffleManager.getOrElse {
+          val created = new StreamingShuffleManager
+          streamingShuffleManager = Some(created)
+          created
+        }
+      } else {
+        sortShuffleManager.getOrElse {
+          val created = new SortShuffleManager(conf)
+          sortShuffleManager = Some(created)
+          created
+        }
+      }
+    })
+  }
+
+  /** Registers the shuffle with whichever manager `shuffleId` is routed to. */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    shuffleIdToManager.synchronized {
+      shuffleManager(shuffleId).registerShuffle(shuffleId, dependency)
+    }
+  }
+
+  /** Obtains a writer from the manager that the handle's shuffle id is routed to. */
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Long,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    shuffleIdToManager.synchronized {
+      shuffleManager(handle.shuffleId).getWriter(handle, mapId, context, metrics)
+    }
+  }
+
+  /** Obtains a reader from the manager that the handle's shuffle id is routed to. */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    shuffleIdToManager.synchronized {
+      shuffleManager(handle.shuffleId).getReader(
+        handle,
+        startMapIndex,
+        endMapIndex,
+        startPartition,
+        endPartition,
+        context,
+        metrics)
+    }
+  }
+
+  /**
+   * Drops the cached routing for `shuffleId` and unregisters the shuffle from the manager it
+   * was routed to.
+   *
+   * @return the delegate's result, or true when no manager was cached for `shuffleId`.
+   */
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    shuffleIdToManager.synchronized {
+      // During unregistering shuffle, which happens when shuffleDependency is garbage
+      // collected, the context might not be active anymore. In this case we perform a no-op:
+      // there is no cached shuffle manager, meaning no other calls (i.e. registerShuffle,
+      // getWriter, or getReader) were previously invoked, so there is no state to clean up.
+      // `remove` returns null for an unknown id, which Option(...) turns into that no-op.
+      Option(shuffleIdToManager.remove(shuffleId)).forall(_.unregisterShuffle(shuffleId))
+    }
+  }
+
+  /**
+   * Returns the sort shuffle manager's block resolver.
+   *
+   * @throws UnsupportedOperationException if no shuffle has been routed to the sort shuffle
+   *                                       manager yet.
+   */
+  override def shuffleBlockResolver: ShuffleBlockResolver = {
+    shuffleIdToManager.synchronized {
+      sortShuffleManager.map(_.shuffleBlockResolver).getOrElse {
+        // don't need to support this for the streaming shuffle implementation
+        // since block manager is not used
+        throw new UnsupportedOperationException(
+          "MultiShuffleManager has no shuffle block resolver until a shuffle has been routed " +
+            "to the sort shuffle manager")
+      }
+    }
+  }
+
+  /** Stops every delegate manager that has been created so far. */
+  override def stop(): Unit = {
+    shuffleIdToManager.synchronized {
+      streamingShuffleManager.foreach(_.stop())
+      sortShuffleManager.foreach(_.stop())
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala
new file mode 100644
index 0000000000000..86566ce3f1675
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import org.apache.spark.{ShuffleDependency, SparkException, SparkRuntimeException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.streaming.{DataMessage, StreamingShuffleMessage, StreamingShuffleMessageType, TerminationControlMessage}
+import org.apache.spark.shuffle._
+
+/** Handle returned by [[StreamingShuffleManager.registerShuffle]]; carries the dependency. */
+class StreamingShuffleHandle[K, V, C](shuffleId: Int, dependency: ShuffleDependency[K, V, C])
+  extends BaseShuffleHandle[K, V, C](shuffleId, dependency)
+
+object StreamingShuffleManager extends Logging {
+  // Exposed for testing
+  private[spark] val QUERY_ID_PROPERTY_KEY = "sql.streaming.queryId"
+  // The key above is not applicable to batch queries, so this id is used to track errors for
+  // batch queries that run with streaming shuffle.
+  private val QUERY_EXECUTION_ID_PROPERTY_KEY = "spark.sql.execution.id"
+
+  /**
+   * Resolves the id of the query a task belongs to: the streaming query id local property if
+   * it is set, otherwise the SQL execution id local property.
+   *
+   * @throws SparkException (internal error) if neither local property is set.
+   */
+  def getQueryId(context: TaskContext): String = {
+    val streamingQueryId = Option(context.getLocalProperty(QUERY_ID_PROPERTY_KEY))
+    // `orElse` takes its argument by name, so the fallback property is only read when needed.
+    val resolvedId = streamingQueryId.orElse(
+      Option(context.getLocalProperty(QUERY_EXECUTION_ID_PROPERTY_KEY)))
+    resolvedId.getOrElse {
+      throw SparkException.internalError(
+        "Streaming shuffle requires the query id or SQL execution id local property to be set")
+    }
+  }
+
+  /**
+   * Called from the reader side to get the writer id associated with a message. Only data and
+   * termination-control messages are accepted; any other message type raises the
+   * STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE error.
+   */
+  def getWriterId(message: StreamingShuffleMessage): Int = {
+    val messageType = message.messageType()
+    messageType match {
+      case StreamingShuffleMessageType.DATA_MESSAGE_UNSAFE_ROW =>
+        message.asInstanceOf[DataMessage].shuffleWriterId
+      case StreamingShuffleMessageType.TERMINATION_CONTROL_MESSAGE =>
+        message.asInstanceOf[TerminationControlMessage].shuffleWriterId
+      case unexpected =>
+        // Should not reach here
+        throw streamingShuffleUnexpectedMessageType(unexpected)
+    }
+  }
+
+  /**
+   * Builds (without throwing) the STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER error for a
+   * message whose actual sequence number differs from the expected one.
+   */
+  def streamingShuffleIncorrectSequenceNumber(
+      messageType: StreamingShuffleMessageType,
+      writerId: Int,
+      readerId: Int,
+      expSeqNum: Long,
+      actSeqNum: Long): RuntimeException = {
+    val parameters = Map(
+      "messageType" -> messageType.toString,
+      "writerId" -> writerId.toString,
+      "readerId" -> readerId.toString,
+      "expSeqNum" -> expSeqNum.toString,
+      "actSeqNum" -> actSeqNum.toString)
+    new SparkRuntimeException(
+      errorClass = "STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER",
+      messageParameters = parameters)
+  }
+
+  /** Builds (without throwing) the STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE error. */
+  def streamingShuffleUnexpectedMessageType(
+      messageType: StreamingShuffleMessageType): RuntimeException = {
+    val parameters = Map("messageType" -> messageType.toString)
+    new SparkRuntimeException(
+      errorClass = "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE",
+      messageParameters = parameters)
+  }
+}
+
+/**
+ * ShuffleManager for streaming shuffle. In this class only `registerShuffle`,
+ * `unregisterShuffle` and `stop` complete normally; `getWriter`, `getReader` and
+ * `shuffleBlockResolver` throw UnsupportedOperationException.
+ */
+private[spark] class StreamingShuffleManager extends ShuffleManager with Logging {
+
+  logInfo(log"Using StreamingShuffleManager")
+
+  /** Wraps the shuffle id and dependency in a [[StreamingShuffleHandle]] to pass to tasks. */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle =
+    new StreamingShuffleHandle(shuffleId, dependency)
+
+  /**
+   * Get a writer for a given partition. Called on executors by map tasks.
+   * Not implemented yet: always throws UnsupportedOperationException.
+   */
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Long,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    // Implementation is added in a follow-up commit that introduces StreamingShuffleWriter.
+    throw new UnsupportedOperationException(
+      "StreamingShuffleManager.getWriter is not yet implemented")
+  }
+
+  /**
+   * Get a reader for the given shuffle. Called on executors by reduce tasks.
+   * Not implemented yet: always throws UnsupportedOperationException.
+   *
+   * For the streaming shuffle the arguments startMapIndex, endMapIndex, startPartition and
+   * endPartition are not relevant.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    // Implementation is added in a follow-up commit that introduces StreamingShuffleReader.
+    throw new UnsupportedOperationException(
+      "StreamingShuffleManager.getReader is not yet implemented")
+  }
+
+  /**
+   * Remove a shuffle's metadata from the ShuffleManager.
+   * @return
+   *   always true, since this manager keeps no per-shuffle state.
+   */
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    // No manager-side state to release here: the driver's StreamingShuffleOutputTracker is
+    // unregistered in BlockManagerStorageEndpoint's RemoveShuffle handler, and per-task writer
+    // and reader resources are released via task completion listeners.
+    true
+  }
+
+  /**
+   * Block resolution is not supported: always throws UnsupportedOperationException.
+   */
+  override def shuffleBlockResolver: ShuffleBlockResolver = {
+    // don't need to support this for the streaming shuffle implementation
+    // since block manager is not used
+    throw new UnsupportedOperationException()
+  }
+
+  /** Shut down this ShuffleManager. There is nothing to release, so this is a no-op. */
+  override def stop(): Unit = ()
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala
new file mode 100644
index 0000000000000..fd0ac89abc79d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.{LogEntry, Logging, LogKeys, MessageWithContext}
+
+/**
+ * Logging mixin that prefixes messages with identifiers taken from a [[TaskContext]].
+ *
+ * String messages are prefixed with `[queryId = ..] [shuffleId = ..] [stageId = ..]
+ * [taskId = ..]`, omitting any part that is unavailable. Structured [[LogEntry]] messages are
+ * prefixed with the STREAMING_QUERY_ID and SHUFFLE_ID keys, using "" and -1 respectively when
+ * the value is unavailable.
+ */
+trait TaskContextAwareLogging extends Logging {
+
+  /** The task context the prefix is derived from. Every use is null-guarded. */
+  def context: TaskContext
+
+  // First five characters of the streaming query id, if the local property is set and
+  // non-empty. Resolved lazily rather than in the trait initializer: a trait body runs before
+  // the body of the class mixing it in, so an eager val would observe a null `context` whenever
+  // the implementation defines `context` as a member val, and would silently drop the query id.
+  private lazy val queryId: Option[String] = Option(context)
+    .flatMap(ctx => Option(ctx.getLocalProperty(StreamingShuffleManager.QUERY_ID_PROPERTY_KEY)))
+    .map(_.take(5))
+    .filter(_.nonEmpty)
+
+  @volatile private var shuffleId: Option[Int] = None
+
+  /** Sets the shuffle id included in every subsequent message logged through this trait. */
+  def setShuffleIdForLogging(shuffleId: Int): Unit = {
+    this.shuffleId = Some(shuffleId)
+  }
+
+  // NOTE(review): the value logged under the "taskId" label is the context's partitionId().
+  private def loadTaskId: Option[String] = Option(context).map(_.partitionId().toString)
+
+  private def loadStageId: Option[String] = Option(context).map(_.stageId().toString)
+
+  /**
+   * Prepends the available identifiers to `msg`, in the order queryId, shuffleId, stageId,
+   * taskId. An identifier that is None contributes nothing to the prefix.
+   */
+  protected def formatMessage(
+      msg: => String,
+      taskId: Option[String] = loadTaskId,
+      stageId: Option[String] = loadStageId): String = {
+    def tag(label: String, value: Option[Any]): String =
+      value.map(v => s"[$label = $v] ").getOrElse("")
+    tag("queryId", queryId) + tag("shuffleId", shuffleId) + tag("stageId", stageId) +
+      tag("taskId", taskId) + msg
+  }
+
+  // Structured-logging counterpart of formatMessage: prefixes `entry` with the query id and
+  // shuffle id keys. Shared by all LogEntry overloads below so that they cannot drift apart.
+  private def withShuffleContext(entry: LogEntry) =
+    log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+      log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry
+
+  override protected def logInfo(msg: => String): Unit =
+    super.logInfo(formatMessage(msg))
+
+  override protected def logInfo(entry: LogEntry): Unit =
+    super.logInfo(withShuffleContext(entry))
+
+  override protected def logWarning(msg: => String): Unit =
+    super.logWarning(formatMessage(msg))
+
+  override protected def logWarning(entry: LogEntry): Unit =
+    super.logWarning(withShuffleContext(entry))
+
+  override protected def logDebug(msg: => String): Unit =
+    super.logDebug(formatMessage(msg))
+
+  override protected def logError(msg: => String): Unit =
+    super.logError(formatMessage(msg))
+
+  override protected def logError(entry: LogEntry): Unit =
+    super.logError(withShuffleContext(entry))
+
+  override protected def logError(entry: LogEntry, throwable: Throwable): Unit =
+    super.logError(withShuffleContext(entry), throwable)
+
+  override protected def logError(msg: => String, throwable: Throwable): Unit =
+    super.logError(formatMessage(msg), throwable)
+
+  /**
+   * Rate limiter for a log call: forwards at most one message per `interval` to `logFn` and
+   * counts the calls it drops in between. The next forwarded message is suffixed with
+   * " (N suppressed)" when N calls were dropped. The counters are plain vars and no
+   * synchronization is performed here.
+   */
+  protected case class LogThrottler(logFn: String => Unit, interval: Duration) {
+    private var nextLogNanos = Long.MinValue
+    private var suppressed = 0
+
+    /** Logs `msg` if the interval has elapsed; `msg` is not evaluated when it is dropped. */
+    def apply(msg: => MessageWithContext): Unit = {
+      val now = System.nanoTime()
+      if (now >= nextLogNanos) {
+        val suffix = if (suppressed > 0) s" ($suppressed suppressed)" else ""
+        logFn(msg.message + suffix)
+        nextLogNanos = now + interval.toNanos
+        suppressed = 0
+      } else {
+        suppressed += 1
+      }
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala
new file mode 100644
index 0000000000000..8d5e9fe4b17c3
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import java.util.Properties
+
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark._
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.shuffle.streaming.MultiShuffleManager.{isStreamingShuffleEnabled, STREAMING_SHUFFLE_ENABLED_PROPERTY}
+
+/** Tests for per-query routing in [[MultiShuffleManager]] and the SparkEnv tracker gating. */
+class MultiShuffleManagerSuite
+  extends SparkFunSuite
+  with LocalSparkContext
+  with Matchers {
+
+  // Local-mode context named after this suite.
+  private def newLocalContext(conf: SparkConf = new SparkConf()): SparkContext =
+    new SparkContext("local", "MultiShuffleManagerSuite", conf)
+
+  // A two-partition hash shuffle dependency over a small pair RDD.
+  private def newDependency(sc: SparkContext): ShuffleDependency[Int, Int, Int] = {
+    val pairs = sc.parallelize(1 to 4).map(x => (x, x))
+    new ShuffleDependency[Int, Int, Int](pairs, new HashPartitioner(2))
+  }
+
+  test("isStreamingShuffleEnabled reflects the per-query property") {
+    val properties = new Properties()
+    // Unset property: disabled.
+    isStreamingShuffleEnabled(properties) shouldBe false
+
+    properties.setProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "true")
+    isStreamingShuffleEnabled(properties) shouldBe true
+
+    properties.setProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "false")
+    isStreamingShuffleEnabled(properties) shouldBe false
+  }
+
+  test("registerShuffle routes to the streaming manager when enabled for the query") {
+    withSpark(newLocalContext()) { sc =>
+      sc.setLocalProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "true")
+      val dependency = newDependency(sc)
+      val manager = new MultiShuffleManager(sc.conf)
+      val handle = manager.registerShuffle(7, dependency)
+      assert(handle.isInstanceOf[StreamingShuffleHandle[_, _, _]])
+    }
+  }
+
+  test("registerShuffle routes to the sort manager when not enabled for the query") {
+    withSpark(newLocalContext()) { sc =>
+      val dependency = newDependency(sc)
+      val manager = new MultiShuffleManager(sc.conf)
+      val handle = manager.registerShuffle(7, dependency)
+      assert(!handle.isInstanceOf[StreamingShuffleHandle[_, _, _]])
+    }
+  }
+
+  test("SparkEnv initializes the streaming shuffle tracker when MultiShuffleManager is set") {
+    val multiManagerConf =
+      new SparkConf().set(SHUFFLE_MANAGER, classOf[MultiShuffleManager].getName)
+    withSpark(newLocalContext(multiManagerConf)) { _ =>
+      assert(SparkEnv.get.streamingShuffleOutputTracker.isDefined)
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala
new file mode 100644
index 0000000000000..b93212e398a69
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import io.netty.buffer.Unpooled
+import org.mockito.Mockito.when
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark._
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.network.shuffle.streaming.{DataMessage, TerminationAckMessage, TerminationControlMessage}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.streaming.StreamingShuffleManager.{getQueryId, getWriterId, QUERY_ID_PROPERTY_KEY}
+
+/**
+ * Tests for the [[StreamingShuffleManager]] companion helpers, `registerShuffle`, and the
+ * conditions under which SparkEnv creates the streaming shuffle output tracker.
+ */
+class StreamingShuffleManagerSuite
+  extends SparkFunSuite
+  with LocalSparkContext
+  with Matchers
+  with MockitoSugar {
+
+  private val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"
+
+  // Local-mode context named after this suite.
+  private def newLocalContext(conf: SparkConf = new SparkConf()): SparkContext =
+    new SparkContext("local", "StreamingShuffleManagerSuite", conf)
+
+  // Mocked TaskContext that answers only the given local properties; any other lookup falls
+  // through to the mock's default answer.
+  private def taskContextWith(properties: (String, String)*): TaskContext = {
+    val context = mock[TaskContext]
+    properties.foreach { case (key, value) =>
+      when(context.getLocalProperty(key)).thenReturn(value)
+    }
+    context
+  }
+
+  // ---- getWriterId ----
+
+  test("getWriterId returns the writer id for a data message") {
+    val dataMessage = new DataMessage(7, 3, 0, Unpooled.EMPTY_BUFFER, 0L)
+    getWriterId(dataMessage) shouldBe 7
+  }
+
+  test("getWriterId returns the writer id for a termination control message") {
+    val controlMessage = new TerminationControlMessage(5, 2)
+    getWriterId(controlMessage) shouldBe 5
+  }
+
+  test("getWriterId throws on an unexpected message type") {
+    val error = intercept[SparkRuntimeException] {
+      getWriterId(new TerminationAckMessage(1, 1))
+    }
+    checkError(
+      error,
+      condition = "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE",
+      parameters = Map("messageType" -> "TERMINATION_ACK_MESSAGE"))
+  }
+
+  // ---- getQueryId ----
+
+  test("getQueryId returns the streaming query id when set") {
+    val context = taskContextWith(QUERY_ID_PROPERTY_KEY -> "query-123")
+    getQueryId(context) shouldBe "query-123"
+  }
+
+  test("getQueryId falls back to the SQL execution id for batch queries") {
+    val context = taskContextWith(SQL_EXECUTION_ID_KEY -> "42")
+    getQueryId(context) shouldBe "42"
+  }
+
+  test("getQueryId throws when no query id property is set") {
+    val context = taskContextWith()
+    intercept[SparkException] {
+      getQueryId(context)
+    }
+  }
+
+  // ---- registerShuffle ----
+
+  test("registerShuffle returns a StreamingShuffleHandle") {
+    withSpark(newLocalContext()) { sc =>
+      val pairs = sc.parallelize(1 to 4).map(x => (x, x))
+      val dependency = new ShuffleDependency[Int, Int, Int](pairs, new HashPartitioner(2))
+      val manager = new StreamingShuffleManager()
+      val handle = manager.registerShuffle(0, dependency)
+      assert(handle.isInstanceOf[StreamingShuffleHandle[_, _, _]])
+    }
+  }
+
+  // ---- SparkEnv tracker initialization gating ----
+
+  test("SparkEnv initializes the streaming shuffle tracker when StreamingShuffleManager is set") {
+    val streamingConf =
+      new SparkConf().set(SHUFFLE_MANAGER, classOf[StreamingShuffleManager].getName)
+    withSpark(newLocalContext(streamingConf)) { _ =>
+      val tracker = SparkEnv.get.streamingShuffleOutputTracker
+      assert(tracker.isDefined)
+      // The driver hosts the master tracker.
+      assert(tracker.get.isInstanceOf[StreamingShuffleOutputTrackerMaster])
+    }
+  }
+
+  test("SparkEnv does not initialize the streaming shuffle tracker for a non-streaming manager") {
+    // A non-streaming shuffle manager (the default sort manager, configured explicitly) must
+    // not trigger StreamingShuffleOutputTracker initialization.
+    val sortConf = new SparkConf().set(SHUFFLE_MANAGER, classOf[SortShuffleManager].getName)
+    withSpark(newLocalContext(sortConf)) { _ =>
+      assert(SparkEnv.get.streamingShuffleOutputTracker.isEmpty)
+    }
+  }
+
+  test("SparkEnv does not initialize the streaming shuffle tracker for the default manager") {
+    withSpark(newLocalContext()) { _ =>
+      assert(SparkEnv.get.streamingShuffleOutputTracker.isEmpty)
+    }
+  }
+}