Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}
import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
import org.apache.spark.util.ArrayImplicits._

/**
Expand Down Expand Up @@ -120,6 +120,14 @@ class SparkEnv (
pythonExec: String, workerModule: String, daemonModule: String, envVars: Map[String, String])
private val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]()

// Background reaper that periodically evicts PythonWorkerFactory instances tied to
// closed Spark Connect sessions (see evictIdlePythonWorkerFactories below). A single
// daemon thread is used so it never prevents JVM shutdown.
private val idleFactoryReaper =
  ThreadUtils.newDaemonSingleThreadScheduledExecutor("idle-python-factory-reaper")
// First sweep after one full interval, then repeated at a fixed rate.
// NOTE(review): scheduleAtFixedRate stops permanently if the task throws — presumably
// evictIdlePythonWorkerFactories cannot throw; confirm, or swallow exceptions in the task.
idleFactoryReaper.scheduleAtFixedRate(
  () => evictIdlePythonWorkerFactories(),
  PythonWorkerFactory.IDLE_FACTORY_CHECK_INTERVAL_MS,
  PythonWorkerFactory.IDLE_FACTORY_CHECK_INTERVAL_MS,
  TimeUnit.MILLISECONDS)

// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata =
Expand All @@ -133,6 +141,7 @@ class SparkEnv (

if (!isStopped) {
isStopped = true
idleFactoryReaper.shutdown()
pythonWorkers.values.foreach(_.stop())
mapOutputTracker.stop()
if (shuffleManager != null) {
Expand Down Expand Up @@ -244,6 +253,39 @@ class SparkEnv (
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker)
}

/**
 * Stops and evicts every cached [[PythonWorkerFactory]] whose worker environment carries
 * the given `SPARK_JOB_ARTIFACT_UUID`. Called when a Spark Connect session closes so that
 * session-scoped Python daemon processes are reclaimed rather than leaked.
 *
 * @param uuid the session's artifact UUID to match against each factory's env vars
 */
private[spark] def destroyPythonWorkersByArtifactUUID(uuid: String): Unit = synchronized {
  // Materialize the matching keys first so we never mutate the map while scanning it.
  val matchingKeys = pythonWorkers.keys
    .filter(_.envVars.get("SPARK_JOB_ARTIFACT_UUID").contains(uuid))
    .toSeq
  for (key <- matchingKeys) {
    // remove returns the evicted factory (if still present); stop it to kill its daemons.
    pythonWorkers.remove(key).foreach(_.stop())
  }
}

/**
 * Stops and evicts [[PythonWorkerFactory]] instances that belong to a specific Spark
 * Connect session (non-default artifact UUID) and have been idle past the factory idle
 * timeout. Covers executor-side cleanup, where session-close notifications never arrive.
 */
private def evictIdlePythonWorkerFactories(): Unit = synchronized {
  // Snapshot the expired keys before mutating the map.
  val expiredKeys = pythonWorkers.collect {
    case (key, factory)
        if factory.isIdleFactory(PythonWorkerFactory.IDLE_FACTORY_TIMEOUT_NS) =>
      key
  }.toSeq
  expiredKeys.foreach { key =>
    pythonWorkers.remove(key).foreach(_.stop())
  }
}

private[spark] def initializeShuffleManager(): Unit = {
Preconditions.checkState(null == _shuffleManager,
"Shuffle manager already initialized to %s", _shuffleManager)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,23 @@ private[spark] class PythonWorkerFactory(
private val maxIdleWorkerPoolSize =
conf.get(PYTHON_FACTORY_IDLE_WORKER_MAX_POOL_SIZE)
@GuardedBy("self")
private var lastActivityNs = 0L
private var lastActivityNs = System.nanoTime()
new MonitorThread().start()

// Artifact UUID taken from the worker environment; identifies the Spark Connect session
// this factory serves, or "default" when the factory is not session-scoped.
// NOTE(review): the env key must stay in sync with the lookup in
// SparkEnv.destroyPythonWorkersByArtifactUUID — confirm both read "SPARK_JOB_ARTIFACT_UUID".
private[spark] val jobArtifactUUID: String =
  envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default")

/**
 * Whether this factory is eligible for eviction: it serves a specific Spark Connect
 * session (non-default artifact UUID), currently holds no idle or daemon workers, and
 * has recorded no activity for longer than `timeoutNs`.
 *
 * @param timeoutNs idle threshold in nanoseconds
 */
private[spark] def isIdleFactory(timeoutNs: Long): Boolean = self.synchronized {
  if (jobArtifactUUID == "default") {
    // Default (non-session) factories are never evicted by the reaper.
    false
  } else {
    val quietLongEnough = (System.nanoTime() - lastActivityNs) > timeoutNs
    idleWorkers.isEmpty && daemonWorkers.isEmpty && quietLongEnough
  }
}

@GuardedBy("self")
private val simpleWorkers = new mutable.WeakHashMap[PythonWorker, Process]()

Expand Down Expand Up @@ -544,5 +558,11 @@ private[spark] object PythonWorkerFactory {
val PROCESS_WAIT_TIMEOUT_MS = 10000
val IDLE_WORKER_TIMEOUT_NS = TimeUnit.MINUTES.toNanos(1) // kill idle workers after 1 minute

// Timeout for evicting entire PythonWorkerFactory instances that belong to closed
// Spark Connect sessions. Factories with a non-default SPARK_JOB_ARTIFACT_UUID that
// have been idle for longer than this are stopped and removed from the cache.
val IDLE_FACTORY_TIMEOUT_NS = TimeUnit.MINUTES.toNanos(5)
val IDLE_FACTORY_CHECK_INTERVAL_MS = 60000L // check every 60 seconds

private[spark] val defaultDaemonModule = "pyspark.daemon"
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,113 @@ class PythonWorkerFactorySuite extends SparkFunSuite with SharedSparkContext {
}
}
}

/**
 * Tests for session-scoped Python worker cleanup: `PythonWorkerFactory.isIdleFactory`
 * and `SparkEnv.destroyPythonWorkersByArtifactUUID`.
 */
class PythonWorkerFactoryIdleSuite extends SparkFunSuite with SharedSparkContext {

  test("isIdleFactory returns false for default artifact UUID") {
    // No SPARK_JOB_ARTIFACT_UUID in envVars, so the factory reports the "default" UUID
    // and must never be considered idle, even with a zero timeout.
    val factory = new PythonWorkerFactory(
      "python3", "pyspark.worker", Map.empty[String, String], true)
    try {
      assert(factory.jobArtifactUUID === "default")
      assert(!factory.isIdleFactory(0))
    } finally {
      factory.stop()
    }
  }

  test("isIdleFactory returns false for session factory with recent activity") {
    // A freshly constructed factory records activity at construction time, so a
    // one-hour timeout cannot have elapsed yet.
    val envVars = Map("SPARK_JOB_ARTIFACT_UUID" -> "test-session-uuid")
    val factory = new PythonWorkerFactory(
      "python3", "pyspark.worker", envVars, true)
    try {
      assert(factory.jobArtifactUUID === "test-session-uuid")
      assert(!factory.isIdleFactory(java.util.concurrent.TimeUnit.HOURS.toNanos(1)))
    } finally {
      factory.stop()
    }
  }

  test("isIdleFactory returns true for session factory past timeout") {
    // With a zero timeout any elapsed time exceeds it, so a worker-less session
    // factory is immediately reported idle.
    val envVars = Map("SPARK_JOB_ARTIFACT_UUID" -> "test-session-uuid")
    val factory = new PythonWorkerFactory(
      "python3", "pyspark.worker", envVars, true)
    try {
      assert(factory.jobArtifactUUID === "test-session-uuid")
      assert(factory.isIdleFactory(0))
    } finally {
      factory.stop()
    }
  }

  test("destroyPythonWorkersByArtifactUUID removes only matching factories") {
    val env = sc.env
    val uuid1 = "session-uuid-1"
    val uuid2 = "session-uuid-2"
    val envVars1 = Map("SPARK_JOB_ARTIFACT_UUID" -> uuid1)
    val envVars2 = Map("SPARK_JOB_ARTIFACT_UUID" -> uuid2)
    val defaultEnvVars = Map("SPARK_JOB_ARTIFACT_UUID" -> "default")

    // Access the internal cache via reflection; `pythonWorkers` and its key type are
    // private to SparkEnv, so the map is manipulated through java.lang.reflect.
    val pythonWorkersField = env.getClass.getDeclaredField("pythonWorkers")
    pythonWorkersField.setAccessible(true)
    val rawMap = pythonWorkersField.get(env)
    val putMethod = rawMap.getClass.getMethod("put", classOf[Object], classOf[Object])
    val sizeMethod = rawMap.getClass.getMethod("size")
    def mapSize(): Int = sizeMethod.invoke(rawMap).asInstanceOf[Int]
    def factoryValues(): Iterable[PythonWorkerFactory] = {
      val valuesMethod = rawMap.getClass.getMethod("values")
      val values = valuesMethod.invoke(rawMap)
        .asInstanceOf[Iterable[PythonWorkerFactory]]
      values
    }

    // Other tests sharing this SparkContext may have populated the cache already;
    // assertions below are relative to this baseline.
    val sizeBefore = mapSize()

    val factory1 = new PythonWorkerFactory("python3", "pyspark.worker", envVars1, true)
    val factory2 = new PythonWorkerFactory("python3", "pyspark.worker", envVars2, true)
    val factoryDefault = new PythonWorkerFactory(
      "python3", "pyspark.worker", defaultEnvVars, true)

    // Construct keys via reflection (PythonWorkersKey is private). It is a nested
    // class, so its constructor takes the enclosing SparkEnv instance first.
    val keyClass = env.getClass.getDeclaredClasses
      .find(_.getSimpleName.contains("PythonWorkersKey")).get
    val keyConstructor = keyClass.getDeclaredConstructors.head
    keyConstructor.setAccessible(true)
    def makeKey(envVars: Map[String, String]): AnyRef =
      keyConstructor.newInstance(
        env, "python3", "pyspark.worker",
        PythonWorkerFactory.defaultDaemonModule, envVars).asInstanceOf[AnyRef]

    val key1 = makeKey(envVars1)
    val key2 = makeKey(envVars2)
    val keyDefault = makeKey(defaultEnvVars)

    try {
      putMethod.invoke(rawMap, key1, factory1)
      putMethod.invoke(rawMap, key2, factory2)
      putMethod.invoke(rawMap, keyDefault, factoryDefault)
      assert(mapSize() === sizeBefore + 3)

      // Destroy factories for uuid1 only; uuid2 and default must survive.
      env.destroyPythonWorkersByArtifactUUID(uuid1)
      assert(mapSize() === sizeBefore + 2)
      assert(factoryValues().exists(_.jobArtifactUUID == uuid2))
      assert(factoryValues().exists(_.jobArtifactUUID == "default"))
      assert(!factoryValues().exists(_.jobArtifactUUID == uuid1))

      // Destroy factories for uuid2; only the default factory remains of the three.
      env.destroyPythonWorkersByArtifactUUID(uuid2)
      assert(mapSize() === sizeBefore + 1)
      assert(!factoryValues().exists(_.jobArtifactUUID == uuid2))
      assert(factoryValues().exists(_.jobArtifactUUID == "default"))
    } finally {
      // Clean up anything this test inserted so the shared SparkEnv is left unchanged.
      // mutable.HashMap.remove returns an Option, surfaced through reflection as Object
      // (never null), so the null check guards only against an unexpected map type.
      val removeMethod = rawMap.getClass.getMethod("remove", classOf[Object])
      Seq(key1, key2, keyDefault).foreach { key =>
        val removed = removeMethod.invoke(rawMap, key)
        if (removed != null) {
          removed.asInstanceOf[Option[PythonWorkerFactory]].foreach(_.stop())
        }
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
// Clean up ML cache (only if ML models were created)
mlCache.close()

session.cleanupPythonWorkers()
session.cleanupPythonWorkerLogs()

eventManager.postClosed()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,16 @@ class SparkSession private(
private[sql] def cleanupPythonWorkerLogs(): Unit = {
PythonSQLUtils.cleanupPythonWorkerLogs(sessionUUID, sparkContext)
}

/**
 * Asks the [[SparkEnv]] cache to stop and drop every PythonWorkerFactory keyed by this
 * session's artifact UUID, so closing a Spark Connect session does not leak Python
 * daemon processes. No-op once the SparkContext has been stopped.
 */
private[sql] def cleanupPythonWorkers(): Unit = {
  val sc = sparkContext
  if (!sc.isStopped) {
    sc.env.destroyPythonWorkersByArtifactUUID(sessionUUID)
  }
}
}


Expand Down