Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ public enum LogKeys implements LogKey {
STREAMING_DATA_SOURCE_NAME,
STREAMING_OFFSETS_END,
STREAMING_OFFSETS_START,
STREAMING_QUERY_ID,
STREAMING_QUERY_PROGRESS,
STREAMING_SOURCE,
STREAMING_TABLE,
Expand Down
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -7130,6 +7130,18 @@
},
"sqlState" : "0A000"
},
"STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER" : {
"message" : [
"Streaming shuffle <messageType> between writer <writerId> and reader <readerId> expected to have sequence number <expSeqNum>, but the actual sequence number is <actSeqNum>. Please verify that the messages are sent in order."
],
"sqlState" : "XXKST"
},
"STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE" : {
"message" : [
"Unexpected message type <messageType> encountered during streaming shuffle."
],
"sqlState" : "XXKST"
},
"STREAMING_STATEFUL_OPERATOR_MISSING_STATE_DIRECTORY" : {
"message" : [
"Cannot restart streaming query with stateful operators because the state directory is empty or missing.",
Expand Down
49 changes: 49 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.streaming.{MultiShuffleManager, StreamingShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.udf.worker.UDFWorkerSpecification
import org.apache.spark.udf.worker.core.{UDFDispatcherFactory, UDFDispatcherManager, WorkerDispatcher}
Expand Down Expand Up @@ -181,6 +182,7 @@ class SparkEnv (
pythonWorkers.values.foreach(_.stop())
udfDispatcherManager.foreach(_.close())
mapOutputTracker.stop()
_streamingShuffleOutputTracker.foreach(_.stop())
if (shuffleManager != null) {
shuffleManager.stop()
}
Expand Down Expand Up @@ -299,6 +301,53 @@ class SparkEnv (
// Signal that the ShuffleManager has been initialized
shuffleManagerInitLatch.countDown()
}
initializeStreamingShuffleOutputTracker()
}

// Holds the streaming shuffle output tracker, which is only present when the configured
// shuffle manager requires it (i.e., StreamingShuffleManager or MultiShuffleManager).
@volatile private var _streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
None

def streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
_streamingShuffleOutputTracker

/**
* Initialize the StreamingShuffleOutputTracker if the configured shuffle manager requires one
* and one does not already exist. This method is idempotent -- calling it multiple times is safe.
*
* This is separated from initializeShuffleManager() to allow the driver to register the
* tracker's RPC endpoint before the SHUFFLE_MANAGER config change propagates to executors,
* eliminating the race condition where executors try to look up the endpoint before the driver
* has registered it.
*/
private def initializeStreamingShuffleOutputTracker(): Unit = {
if (_streamingShuffleOutputTracker.isDefined) {
return
}

val shuffleManagerName = ShuffleManager.getShuffleManagerClassName(conf)
if (shuffleManagerName == classOf[StreamingShuffleManager].getName
|| shuffleManagerName == classOf[MultiShuffleManager].getName) {
val tracker = if (SparkContext.isDriver(executorId)) {
new StreamingShuffleOutputTrackerMaster(conf)
} else {
new StreamingShuffleOutputTrackerWorker(conf)
}

if (SparkContext.isDriver(executorId)) {
tracker.trackerEndpoint = rpcEnv.setupEndpoint(
StreamingShuffleOutputTracker.ENDPOINT_NAME,
new StreamingShuffleOutputTrackerMasterEndpoint(
rpcEnv,
tracker.asInstanceOf[StreamingShuffleOutputTrackerMaster],
conf))
} else {
tracker.trackerEndpoint = RpcUtils.makeDriverRef(
StreamingShuffleOutputTracker.ENDPOINT_NAME, conf, rpcEnv)
}
_streamingShuffleOutputTracker = Some(tracker)
}
}

private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.streaming

import java.util.Properties
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.{ShuffleBlockResolver, ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleReadMetricsReporter, ShuffleWriteMetricsReporter, ShuffleWriter}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.shuffle.streaming.MultiShuffleManager.isStreamingShuffleEnabled

class MultiShuffleHandle(
val streamingShuffleHandle: ShuffleHandle,
val otherShuffleHandle: ShuffleHandle)
extends ShuffleHandle(streamingShuffleHandle.shuffleId)

object MultiShuffleManager {
val STREAMING_SHUFFLE_ENABLED_PROPERTY = "spark.shuffle.streaming.useForCurrentQuery"

def isStreamingShuffleEnabled(properties: Properties): Boolean =
"true" == properties.getProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY)
}

/* This shuffle manager is used to allow real-time queries that depends on streaming shuffle
and normal queries that depends on sort shuffle to coexist in a cluster. Right now, we only
allows configuration of shuffle manager at cluster level, so consider using this shuffle
manager if you want to run batch and real time queries at the same time.
*/
class MultiShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
// To make sure the type of shuffle manager used for a shuffle is the same during its lifetime
private val shuffleIdToManager = new ConcurrentHashMap[Int, ShuffleManager]()
private var streamingShuffleManager: Option[StreamingShuffleManager] = None
private var sortShuffleManager: Option[SortShuffleManager] = None

private def shuffleManager(shuffleId: Int): ShuffleManager = {
shuffleIdToManager.computeIfAbsent(shuffleId, _ => {
val properties = SparkContext.getActive.map(_.getLocalProperties)
.orElse(Option(TaskContext.get()).map(_.getLocalProperties))
.getOrElse(throw SparkException.internalError(
"Cannot determine streaming shuffle routing: no active SparkContext or TaskContext"))
if (isStreamingShuffleEnabled(properties)) {
if (streamingShuffleManager.isEmpty) {
streamingShuffleManager = Some(new StreamingShuffleManager)
}
streamingShuffleManager.get
} else {
if (sortShuffleManager.isEmpty) {
sortShuffleManager = Some(new SortShuffleManager(conf))
}
sortShuffleManager.get
}
})
}

override def registerShuffle[K, V, C](
shuffleId: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
shuffleIdToManager.synchronized {
shuffleManager(shuffleId).registerShuffle(shuffleId, dependency)
}
}

override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
shuffleIdToManager.synchronized {
shuffleManager(handle.shuffleId).getWriter(handle, mapId, context, metrics)
}
}

override def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
shuffleIdToManager.synchronized {
shuffleManager(handle.shuffleId).getReader(
handle,
startMapIndex,
endMapIndex,
startPartition,
endPartition,
context,
metrics)
}
}

override def unregisterShuffle(shuffleId: Int): Boolean = {
shuffleIdToManager.synchronized {
val manager = shuffleIdToManager.get(shuffleId)
// During unregistering shuffle, which happens when shuffleDependency is garbage
// collected, the context might not be active anymore, in this case, we will
// perform no-op since there is no cached shuffle manager, meaning
// there are no other calls (i.e registerShuffle, getWriter, or getReader) previously
// invoked, thereby no state to cleanup
if (manager == null) {
return true
}

shuffleIdToManager.remove(shuffleId)
manager.unregisterShuffle(shuffleId)
}
}

override def shuffleBlockResolver: ShuffleBlockResolver = {
shuffleIdToManager.synchronized {
if (sortShuffleManager.nonEmpty) {
sortShuffleManager.get.shuffleBlockResolver
} else {
// don't need to support this for the streaming shuffle implementation
// since block manager is not used
throw new UnsupportedOperationException()
}
}
}

override def stop(): Unit = {
shuffleIdToManager.synchronized {
if (streamingShuffleManager.nonEmpty) {
streamingShuffleManager.get.stop()
}
if (sortShuffleManager.nonEmpty) {
sortShuffleManager.get.stop()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.streaming

import org.apache.spark.{ShuffleDependency, SparkException, SparkRuntimeException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.network.shuffle.streaming.{DataMessage, StreamingShuffleMessage, StreamingShuffleMessageType, TerminationControlMessage}
import org.apache.spark.shuffle._

class StreamingShuffleHandle[K, V, C](shuffleId: Int, dependency: ShuffleDependency[K, V, C])
extends BaseShuffleHandle[K, V, C](shuffleId, dependency)

object StreamingShuffleManager extends Logging {
// Exposed for testing
private[spark] val QUERY_ID_PROPERTY_KEY = "sql.streaming.queryId"
// Since above is not applicable for batch query, we use below id to track error for batch
// query with streaming shuffle
private val QUERY_EXECUTION_ID_PROPERTY_KEY = "spark.sql.execution.id"

def getQueryId(context: TaskContext): String = {
Option(context.getLocalProperty(QUERY_ID_PROPERTY_KEY))
.orElse(Option(context.getLocalProperty(QUERY_EXECUTION_ID_PROPERTY_KEY)))
.getOrElse(throw SparkException.internalError(
"Streaming shuffle requires the query id or SQL execution id local property to be set"))
}

/* Called from the reader side to get the writerId associated with a message */
def getWriterId(message: StreamingShuffleMessage): Int = {
message.messageType() match {
case StreamingShuffleMessageType.DATA_MESSAGE_UNSAFE_ROW =>
message.asInstanceOf[DataMessage].shuffleWriterId
case StreamingShuffleMessageType.TERMINATION_CONTROL_MESSAGE =>
message.asInstanceOf[TerminationControlMessage].shuffleWriterId
case _ =>
// Should not reach here
throw streamingShuffleUnexpectedMessageType(message.messageType());
}
}

def streamingShuffleIncorrectSequenceNumber(
messageType: StreamingShuffleMessageType,
writerId: Int,
readerId: Int,
expSeqNum: Long,
actSeqNum: Long): RuntimeException = {
new SparkRuntimeException(
errorClass = "STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER",
messageParameters = Map(
"messageType" -> messageType.toString,
"writerId" -> writerId.toString,
"readerId" -> readerId.toString,
"expSeqNum" -> expSeqNum.toString,
"actSeqNum" -> actSeqNum.toString))
}

def streamingShuffleUnexpectedMessageType(
messageType: StreamingShuffleMessageType): RuntimeException = {
new SparkRuntimeException(
errorClass = "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE",
messageParameters = Map("messageType" -> messageType.toString))
}
}

private[spark] class StreamingShuffleManager extends ShuffleManager with Logging {

logInfo(log"Using StreamingShuffleManager")

/**
* Register a shuffle with the manager and obtain a handle for it to pass to tasks.
*/
override def registerShuffle[K, V, C](
shuffleId: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
new StreamingShuffleHandle(shuffleId, dependency)
}

/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
// Implementation is added in a follow-up commit that introduces StreamingShuffleWriter.
throw new UnsupportedOperationException(
"StreamingShuffleManager.getWriter is not yet implemented")
}

/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive)
* to read from a range of map outputs(startMapIndex to endMapIndex-1, inclusive). If
* endMapIndex=Int.MaxValue, the actual endMapIndex will be changed to the length of total map
* outputs of the shuffle in `getMapSizesByExecutorId`.
*
* Called on executors by reduce tasks.
*
* For the streaming shuffle arguments startMapIndex, endMapIndex, startPartition,
* and endPartition are not relevant
*/
override def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
// Implementation is added in a follow-up commit that introduces StreamingShuffleReader.
throw new UnsupportedOperationException(
"StreamingShuffleManager.getReader is not yet implemented")
}

/**
* Remove a shuffle's metadata from the ShuffleManager.
* @return
* true if the metadata removed successfully, otherwise false.
*/
override def unregisterShuffle(shuffleId: Int): Boolean = {
// No manager-side state to release here: the driver's StreamingShuffleOutputTracker is
// unregistered in BlockManagerStorageEndpoint's RemoveShuffle handler, and per-task writer
// and reader resources are released via task completion listeners.
true
}

/**
* Return a resolver capable of retrieving shuffle block data based on block coordinates.
*/
override def shuffleBlockResolver: ShuffleBlockResolver = {
// don't need to support this for the streaming shuffle implementation
// since block manager is not used
throw new UnsupportedOperationException()
}

/** Shut down this ShuffleManager. */
override def stop(): Unit = {}
}
Loading