diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 7dcf66a609577..228fa5ebaaf95 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -47,7 +47,7 @@ import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ /** @@ -120,6 +120,14 @@ class SparkEnv ( pythonExec: String, workerModule: String, daemonModule: String, envVars: Map[String, String]) private val pythonWorkers = mutable.HashMap[PythonWorkersKey, PythonWorkerFactory]() + private val idleFactoryReaper = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("idle-python-factory-reaper") + idleFactoryReaper.scheduleAtFixedRate( + () => evictIdlePythonWorkerFactories(), + PythonWorkerFactory.IDLE_FACTORY_CHECK_INTERVAL_MS, + PythonWorkerFactory.IDLE_FACTORY_CHECK_INTERVAL_MS, + TimeUnit.MILLISECONDS) + // A general, soft-reference map for metadata needed during HadoopRDD split computation // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). private[spark] val hadoopJobMetadata = @@ -133,6 +141,7 @@ class SparkEnv ( if (!isStopped) { isStopped = true + idleFactoryReaper.shutdown() pythonWorkers.values.foreach(_.stop()) mapOutputTracker.stop() if (shuffleManager != null) { @@ -244,6 +253,39 @@ class SparkEnv ( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker) } + /** + * Stop and remove all [[PythonWorkerFactory]] instances whose environment contains the given + * `SPARK_JOB_ARTIFACT_UUID`. This is called when a Spark Connect session closes so that + * per-session daemon processes do not leak. + */ + private[spark] def destroyPythonWorkersByArtifactUUID(uuid: String): Unit = { + synchronized { + val keysToRemove = pythonWorkers.collect { + case (key, _) if key.envVars.get("SPARK_JOB_ARTIFACT_UUID").contains(uuid) => key + }.toSeq + keysToRemove.foreach { key => + pythonWorkers.remove(key).foreach(_.stop()) + } + } + } + + /** + * Evict [[PythonWorkerFactory]] instances that belong to a specific Spark Connect session + * (non-default artifact UUID) and have been idle for longer than the factory idle timeout. + * This handles the executor-side cleanup where session close notifications are not received. + */ + private def evictIdlePythonWorkerFactories(): Unit = { + synchronized { + val keysToRemove = pythonWorkers.collect { + case (key, factory) if factory.isIdleFactory( + PythonWorkerFactory.IDLE_FACTORY_TIMEOUT_NS) => key + }.toSeq + keysToRemove.foreach { key => + pythonWorkers.remove(key).foreach(_.stop()) + } + } + } + private[spark] def initializeShuffleManager(): Unit = { Preconditions.checkState(null == _shuffleManager, "Shuffle manager already initialized to %s", _shuffleManager) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 350818e18cb98..656db4ca5c0fd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -118,9 +118,23 @@ private[spark] class PythonWorkerFactory( private val maxIdleWorkerPoolSize = conf.get(PYTHON_FACTORY_IDLE_WORKER_MAX_POOL_SIZE) @GuardedBy("self") - private var lastActivityNs = 0L + private var lastActivityNs = System.nanoTime() new MonitorThread().start() + private[spark] val jobArtifactUUID: String = + envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default") + + /** + * Returns true if this factory has a non-default artifact UUID (i.e. it belongs to a + * specific Spark Connect session) and has had no activity for longer than the given timeout. + */ + private[spark] def isIdleFactory(timeoutNs: Long): Boolean = self.synchronized { + jobArtifactUUID != "default" && + idleWorkers.isEmpty && + daemonWorkers.isEmpty && + (System.nanoTime() - lastActivityNs) > timeoutNs + } + @GuardedBy("self") private val simpleWorkers = new mutable.WeakHashMap[PythonWorker, Process]() @@ -544,5 +558,11 @@ private[spark] object PythonWorkerFactory { val PROCESS_WAIT_TIMEOUT_MS = 10000 val IDLE_WORKER_TIMEOUT_NS = TimeUnit.MINUTES.toNanos(1) // kill idle workers after 1 minute + // Timeout for evicting entire PythonWorkerFactory instances that belong to closed + // Spark Connect sessions. Factories with a non-default SPARK_JOB_ARTIFACT_UUID that + // have been idle for longer than this are stopped and removed from the cache. + val IDLE_FACTORY_TIMEOUT_NS = TimeUnit.MINUTES.toNanos(5) + val IDLE_FACTORY_CHECK_INTERVAL_MS = 60000L // check every 60 seconds + private[spark] val defaultDaemonModule = "pyspark.daemon" } diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala index 4f9dafb6cbeae..a1fbfdcf7f06e 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonWorkerFactorySuite.scala @@ -110,3 +110,113 @@ class PythonWorkerFactorySuite extends SparkFunSuite with SharedSparkContext { } } } + +class PythonWorkerFactoryIdleSuite extends SparkFunSuite with SharedSparkContext { + + test("isIdleFactory returns false for default artifact UUID") { + val factory = new PythonWorkerFactory( + "python3", "pyspark.worker", Map.empty[String, String], true) + try { + assert(factory.jobArtifactUUID === "default") + assert(!factory.isIdleFactory(0)) + } finally { + factory.stop() + } + } + + test("isIdleFactory returns false for session factory with recent activity") { + val envVars = Map("SPARK_JOB_ARTIFACT_UUID" -> "test-session-uuid") + val factory = new PythonWorkerFactory( + "python3", "pyspark.worker", envVars, true) + try { + assert(factory.jobArtifactUUID === "test-session-uuid") + assert(!factory.isIdleFactory(java.util.concurrent.TimeUnit.HOURS.toNanos(1))) + } finally { + factory.stop() + } + } + + test("isIdleFactory returns true for session factory past timeout") { + val envVars = Map("SPARK_JOB_ARTIFACT_UUID" -> "test-session-uuid") + val factory = new PythonWorkerFactory( + "python3", "pyspark.worker", envVars, true) + try { + assert(factory.jobArtifactUUID === "test-session-uuid") + assert(factory.isIdleFactory(0)) + } finally { + factory.stop() + } + } + + test("destroyPythonWorkersByArtifactUUID removes only matching factories") { + val env = sc.env + val uuid1 = "session-uuid-1" + val uuid2 = "session-uuid-2" + val envVars1 = Map("SPARK_JOB_ARTIFACT_UUID" -> uuid1) + val envVars2 = Map("SPARK_JOB_ARTIFACT_UUID" -> uuid2) + val defaultEnvVars = Map("SPARK_JOB_ARTIFACT_UUID" -> "default") + + // Access the internal cache via reflection + val pythonWorkersField = env.getClass.getDeclaredField("pythonWorkers") + pythonWorkersField.setAccessible(true) + val rawMap = pythonWorkersField.get(env) + val putMethod = rawMap.getClass.getMethod("put", classOf[Object], classOf[Object]) + val sizeMethod = rawMap.getClass.getMethod("size") + def mapSize(): Int = sizeMethod.invoke(rawMap).asInstanceOf[Int] + def factoryValues(): Iterable[PythonWorkerFactory] = { + val valuesMethod = rawMap.getClass.getMethod("values") + val values = valuesMethod.invoke(rawMap) + .asInstanceOf[Iterable[PythonWorkerFactory]] + values + } + + val sizeBefore = mapSize() + + val factory1 = new PythonWorkerFactory("python3", "pyspark.worker", envVars1, true) + val factory2 = new PythonWorkerFactory("python3", "pyspark.worker", envVars2, true) + val factoryDefault = new PythonWorkerFactory( + "python3", "pyspark.worker", defaultEnvVars, true) + + // Construct keys via reflection (PythonWorkersKey is private) + val keyClass = env.getClass.getDeclaredClasses + .find(_.getSimpleName.contains("PythonWorkersKey")).get + val keyConstructor = keyClass.getDeclaredConstructors.head + keyConstructor.setAccessible(true) + def makeKey(envVars: Map[String, String]): AnyRef = + keyConstructor.newInstance( + env, "python3", "pyspark.worker", + PythonWorkerFactory.defaultDaemonModule, envVars).asInstanceOf[AnyRef] + + val key1 = makeKey(envVars1) + val key2 = makeKey(envVars2) + val keyDefault = makeKey(defaultEnvVars) + + try { + putMethod.invoke(rawMap, key1, factory1) + putMethod.invoke(rawMap, key2, factory2) + putMethod.invoke(rawMap, keyDefault, factoryDefault) + assert(mapSize() === sizeBefore + 3) + + // Destroy factories for uuid1 only + env.destroyPythonWorkersByArtifactUUID(uuid1) + assert(mapSize() === sizeBefore + 2) + assert(factoryValues().exists(_.jobArtifactUUID == uuid2)) + assert(factoryValues().exists(_.jobArtifactUUID == "default")) + assert(!factoryValues().exists(_.jobArtifactUUID == uuid1)) + + // Destroy factories for uuid2 + env.destroyPythonWorkersByArtifactUUID(uuid2) + assert(mapSize() === sizeBefore + 1) + assert(!factoryValues().exists(_.jobArtifactUUID == uuid2)) + assert(factoryValues().exists(_.jobArtifactUUID == "default")) + } finally { + val removeMethod = rawMap.getClass.getMethod("remove", classOf[Object]) + Seq(key1, key2, keyDefault).foreach { key => + val removed = removeMethod.invoke(rawMap, key) + if (removed != null) { + removed.asInstanceOf[Option[PythonWorkerFactory]].foreach(_.stop()) + } + } + } + } +} diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 307416a659f7a..ebed9a3ea3006 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -418,6 +418,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio // Clean up ML cache (only if ML models were created) mlCache.close() + session.cleanupPythonWorkers() session.cleanupPythonWorkerLogs() eventManager.postClosed() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index f03b4796314b7..533fd90eaea4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -924,6 +924,16 @@ class SparkSession private( private[sql] def cleanupPythonWorkerLogs(): Unit = { PythonSQLUtils.cleanupPythonWorkerLogs(sessionUUID, sparkContext) } + + /** + * Stops and removes all PythonWorkerFactory instances associated with this session's + * artifact UUID. Prevents daemon process leaks when Spark Connect sessions are closed. + */ + private[sql] def cleanupPythonWorkers(): Unit = { + if (!sparkContext.isStopped) { + sparkContext.env.destroyPythonWorkersByArtifactUUID(sessionUUID) + } + } }