[SPARK-56315][SQL] Pre-aggregate before Expand to reduce data amplification for multiple COUNT(DISTINCT)#55130
LuciferYang wants to merge 6 commits into apache:master from
…Benchmark (JDK 17, Scala 2.13, split 1 of 1)
…Benchmark (JDK 21, Scala 2.13, split 1 of 1)
…Benchmark (JDK 25, Scala 2.13, split 1 of 1)
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

class OptimizeExpandSuite extends PlanTest {
Consider a test for COUNT(DISTINCT col1 + col2). The pre-aggregate groups by leaf attributes (col1, col2) not the expression, so dedup is less effective but still correct. Worth documenting.
Thanks for the suggestion! Added tests in the latest commit:
Plan-level test (OptimizeExpandSuite): verifies that for COUNT(DISTINCT col1 + col2), the pre-aggregate groups by leaf attributes (key, col1, col2, col3) rather than (key, col1+col2, col3). The comment documents the trade-off — dedup is less effective (more groups than strictly necessary) but correctness is preserved.
Correctness test (OptimizeExpandQuerySuite): uses coprime moduli (col1 = id % 7, col2 = id % 11) to generate data where 77 unique (col1, col2) pairs map to only 17 unique col1 + col2 values, ensuring the "less effective dedup" scenario is actually exercised. Verifies optimized results match the non-optimized baseline.
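The arithmetic behind that test data can be checked with a quick plain-Python sketch (standalone, not Spark; the moduli 7 and 11 are taken directly from the description above):

```python
# With coprime moduli 7 and 11, the pairs (id % 7, id % 11) take
# 7 * 11 = 77 distinct values, while the sums id % 7 + id % 11 only
# range over 0..16, i.e. 17 distinct values. So grouping by the leaf
# attributes keeps 77 groups where grouping by the expression would keep 17.
ids = range(1_000)
pairs = {(i % 7, i % 11) for i in ids}
sums = {(i % 7) + (i % 11) for i in ids}
print(len(pairs), len(sums))  # → 77 17
```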
Address review feedback: document and test that the pre-aggregate groups by leaf attributes (col1, col2) rather than the expression (col1 + col2), making dedup less effective but still correct.
What changes were proposed in this pull request?
Add an optimizer rule OptimizeExpand that inserts a de-duplication aggregate before the Expand operator produced by RewriteDistinctAggregates.

Queries with multiple COUNT(DISTINCT) on different columns are rewritten by RewriteDistinctAggregates into a plan that duplicates each input row N times via an Expand operator (where N = number of distinct groups). This data amplification becomes a bottleneck as N grows. This rule transforms the plan from:
to:
The inserted pre-aggregation eliminates duplicate rows before the Expand applies its N-fold amplification, significantly reducing shuffled data volume.
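As a toy illustration of the row-count effect (plain Python, not the actual Spark operators; the row count and column moduli here are made up), deduplicating on the grouping key plus the distinct columns before the N-fold Expand shrinks what gets amplified:

```python
# Toy model: 3 COUNT(DISTINCT) aggregates mean Expand emits N = 3 copies
# of every input row. Deduplicating on (key, a, b, c) first means only
# the distinct combinations get amplified.
rows = [("k1", i % 7, i % 11, i % 13) for i in range(10_000)]  # (key, a, b, c)
n_distinct_groups = 3

# Without the rule: Expand duplicates every input row N times.
expanded_naive = len(rows) * n_distinct_groups

# With the rule: pre-aggregate (== distinct on leaf attributes), then expand.
expanded_optimized = len(set(rows)) * n_distinct_groups

print(expanded_naive, expanded_optimized)  # → 30000 3003
```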
The optimization is controlled by spark.sql.optimizer.optimizeExpandRatio (default -1 = disabled, minimum 2). It only applies when the plan meets several conditions (one involving non-distinct aggregates such as SUM(value)).

Why are the changes needed?
For queries like SELECT key, COUNT(DISTINCT a), COUNT(DISTINCT b), COUNT(DISTINCT c) FROM t GROUP BY key, the Expand operator duplicates each row 3x. With 6 distinct aggregates, it's 6x. The pre-aggregation step eliminates redundant rows before this amplification, converting O(N * input_rows) to O(N * distinct_groups).

Benchmark results
Pure distinct aggregates (6 COUNT(DISTINCT), JDK 17):
Scaling with number of distinct aggregates (JDK 17):
Mixed aggregates (with SUM): no regression — baseline and optimized are within noise.
Results are consistent across JDK 17, 21, and 25.
Does this PR introduce any user-facing change?
No. The optimization is disabled by default (spark.sql.optimizer.optimizeExpandRatio = -1). When enabled, query results are unchanged; only execution performance is affected.

How was this patch tested?
Add new tests:

- OptimizeExpandSuite (plan-level rule tests)
- OptimizeExpandQuerySuite (result correctness)
- ExpandBenchmark with results for JDK 17, 21, and 25

Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code