diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -108,7 +108,8 @@ case class SortExec(
       newRowSorter.setTestSpillFrequency(testSpillFrequency)
     }
     rowSorter.set(newRowSorter)
-    rowSorter.get()
+    // Hand back the instance just stored in this thread's slot; no second lookup needed.
+    newRowSorter
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -204,8 +205,10 @@ case class SortExec(
    * cleanupResources before rowSorter is initialized in createSorter.
    */
   override protected[sql] def cleanupResources(): Unit = {
-    if (rowSorter.get() != null) {
-      rowSorter.get().cleanupResources()
+    // Resolve the calling thread's sorter once; null means createSorter has not run here.
+    val sorter = rowSorter.get()
+    if (sorter != null) {
+      sorter.cleanupResources()
     }
     super.cleanupResources()
   }