Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ case class SortExec(
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))

// Each task thread has its own UnsafeExternalRowSorter instance stored here.
// Using a stable lazy val (rather than a reassigned var) ensures that the ThreadLocal
// object itself is never replaced: concurrent tasks on different threads each get their
// own independent slot in the same ThreadLocal, so one task can never observe or clobber
// another task's sorter reference.
@transient private[sql] lazy val rowSorter: ThreadLocal[UnsafeExternalRowSorter] =
new ThreadLocal[UnsafeExternalRowSorter]()
// WARNING: This is a shared mutable var on the SortExec instance. It must not be accessed
// from multiple threads concurrently: Spark operators do not guarantee thread-safety, and if
// two tasks race on this field, one task's sorter can overwrite the other's.
private[sql] var rowSorter: UnsafeExternalRowSorter = _

/**
* This method gets invoked only once for each SortExec instance to initialize an
Expand Down Expand Up @@ -101,14 +98,13 @@ case class SortExec(
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val newRowSorter = UnsafeExternalRowSorter.create(
rowSorter = UnsafeExternalRowSorter.create(
schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)

if (testSpillFrequency > 0) {
newRowSorter.setTestSpillFrequency(testSpillFrequency)
rowSorter.setTestSpillFrequency(testSpillFrequency)
}
rowSorter.set(newRowSorter)
rowSorter.get()
rowSorter
}

protected override def doExecute(): RDD[InternalRow] = {
Expand Down Expand Up @@ -204,8 +200,8 @@ case class SortExec(
* cleanupResources before rowSorter is initialized in createSorter.
*/
override protected[sql] def cleanupResources(): Unit = {
if (rowSorter.get() != null) {
rowSorter.get().cleanupResources()
if (rowSorter != null) {
rowSorter.cleanupResources()
}
super.cleanupResources()
}
Expand Down