Skip to content

Commit 2970087

Browse files
sharifhsn and claude
committed
PERF: Numba union-find for spatio-temporal cluster labeling
Add a JIT-compiled union-find kernel (_st_fused_ccl) that replaces the Python BFS-based _get_clusters_st for spatio-temporal adjacency when Numba is available. The kernel performs single-pass connected component labeling with path compression and union-by-rank. Bundle the tightly-coupled changes required for the kernel: - _setup_adjacency returns pre-computed CSR arrays (indptr, indices) to avoid redundant rebuilds in the permutation loop - _sums_only parameter skips cluster list construction when only aggregate sums are needed (permutation inner loop), using np.bincount instead - _csr_data parameter threads CSR arrays through the call chain These pieces are bundled because _sums_only only takes effect inside the `if has_numba:` block, and CSR data is only consumed by the Numba kernel. Splitting them would create dead code paths. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent aeffa50 commit 2970087

1 file changed

Lines changed: 226 additions & 11 deletions

File tree

mne/stats/cluster_level.py

Lines changed: 226 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,117 @@ def _sum_cluster_data(data, tstep):
114114
return np.sign(data) * np.logical_not(data == 0) * tstep
115115

116116

@jit()
def _st_fused_ccl(
    active_idx, n_active, flat_to_active, adj_indptr, adj_indices, n_src, max_step
):
    """Label spatio-temporal connected components via compiled union-find.

    Two active vertices are merged when they are spatial neighbors (per
    the CSR adjacency) at the same time point, or the same spatial
    vertex at time points at most ``max_step`` apart.

    Parameters
    ----------
    active_idx : ndarray of intp
        Flat indices of active (supra-threshold) vertices.
    n_active : int
        Number of active vertices (len(active_idx)).
    flat_to_active : ndarray of intp, shape (n_total,)
        Pre-allocated lookup buffer, initialized to -1.
    adj_indptr : ndarray of intp, shape (n_src + 1,)
        CSR indptr for spatial adjacency.
    adj_indices : ndarray of intp
        CSR indices for spatial adjacency.
    n_src : int
        Number of spatial vertices.
    max_step : int
        Maximum temporal step for adjacency.

    Returns
    -------
    components : ndarray of intp
        Component labels (0..n_components-1) for each active vertex.
    """
    # Invert active_idx so any flat index resolves to its active-array
    # position in O(1); touches only the n_active used slots.
    for pos in range(n_active):
        flat_to_active[active_idx[pos]] = pos

    parent = np.arange(n_active)
    rank = np.zeros(n_active, dtype=np.int32)

    def _find(node):
        # Path-halving find: each visited node is re-pointed at its
        # grandparent while walking up to the root.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    def _union(a, b):
        # Union by rank; on a rank tie, b's root is attached under a's.
        ra = _find(a)
        rb = _find(b)
        if ra != rb:
            if rank[ra] < rank[rb]:
                parent[ra] = rb
            elif rank[ra] > rank[rb]:
                parent[rb] = ra
            else:
                parent[rb] = ra
                rank[ra] += 1

    for a_pos in range(n_active):
        flat_i = active_idx[a_pos]
        t_i = flat_i // n_src
        s_i = flat_i - t_i * n_src

        # Merge with active spatial neighbors at the same time point.
        for ptr in range(adj_indptr[s_i], adj_indptr[s_i + 1]):
            b_pos = flat_to_active[t_i * n_src + adj_indices[ptr]]
            if b_pos >= 0:
                _union(a_pos, b_pos)

        # Merge with the same spatial vertex at earlier time points.
        # Steps only grow, so the first out-of-range step ends the scan.
        for step in range(1, max_step + 1):
            if t_i < step:
                break
            b_pos = flat_to_active[(t_i - step) * n_src + s_i]
            if b_pos >= 0:
                _union(a_pos, b_pos)

    # Flatten every tree and relabel roots in order of first appearance,
    # yielding contiguous labels 0..n_components-1.
    label_map = -np.ones(n_active, dtype=np.intp)
    next_label = np.intp(0)
    components = np.empty(n_active, dtype=np.intp)
    for pos in range(n_active):
        root = pos
        while parent[root] != root:
            root = parent[root]
        parent[pos] = root
        if label_map[root] == -1:
            label_map[root] = next_label
            next_label += 1
        components[pos] = label_map[root]

    # Reset the shared lookup buffer to all -1 so the next call can
    # reuse it without a full reallocation (O(n_active) only).
    for pos in range(n_active):
        flat_to_active[active_idx[pos]] = -1

    return components
226+
227+
117228
def _get_clusters_spatial(s, neighbors):
118229
"""Form spatial clusters using neighbor lists.
119230
@@ -330,6 +441,8 @@ def _find_clusters(
330441
partitions=None,
331442
t_power=1,
332443
show_info=False,
444+
_sums_only=False,
445+
_csr_data=None,
333446
):
334447
"""Find all clusters which are above/below a certain threshold.
335448
@@ -461,7 +574,15 @@ def _find_clusters(
461574
for x_in in x_ins:
462575
if np.any(x_in):
463576
out = _find_clusters_1dir_parts(
464-
x, x_in, adjacency, max_step, partitions, t_power, ndimage
577+
x,
578+
x_in,
579+
adjacency,
580+
max_step,
581+
partitions,
582+
t_power,
583+
ndimage,
584+
_sums_only=_sums_only and not tfce,
585+
_csr_data=_csr_data,
465586
)
466587
clusters += out[0]
467588
sums.append(out[1])
@@ -494,27 +615,60 @@ def _find_clusters(
494615

495616

496617
def _find_clusters_1dir_parts(
497-
x, x_in, adjacency, max_step, partitions, t_power, ndimage
618+
x,
619+
x_in,
620+
adjacency,
621+
max_step,
622+
partitions,
623+
t_power,
624+
ndimage,
625+
_sums_only=False,
626+
_csr_data=None,
498627
):
499628
"""Deal with partitions, and pass the work to _find_clusters_1dir."""
500629
if partitions is None:
501630
clusters, sums = _find_clusters_1dir(
502-
x, x_in, adjacency, max_step, t_power, ndimage
631+
x,
632+
x_in,
633+
adjacency,
634+
max_step,
635+
t_power,
636+
ndimage,
637+
_sums_only,
638+
_csr_data=_csr_data,
503639
)
504640
else:
505641
# cluster each partition separately
506642
clusters = list()
507643
sums = list()
508644
for p in range(np.max(partitions) + 1):
509645
x_i = np.logical_and(x_in, partitions == p)
510-
out = _find_clusters_1dir(x, x_i, adjacency, max_step, t_power, ndimage)
646+
out = _find_clusters_1dir(
647+
x,
648+
x_i,
649+
adjacency,
650+
max_step,
651+
t_power,
652+
ndimage,
653+
_sums_only,
654+
_csr_data=_csr_data,
655+
)
511656
clusters += out[0]
512657
sums.append(out[1])
513658
sums = np.concatenate(sums)
514659
return clusters, sums
515660

516661

517-
def _find_clusters_1dir(x, x_in, adjacency, max_step, t_power, ndimage):
662+
def _find_clusters_1dir(
663+
x,
664+
x_in,
665+
adjacency,
666+
max_step,
667+
t_power,
668+
ndimage,
669+
_sums_only=False,
670+
_csr_data=None,
671+
):
518672
"""Actually call the clustering algorithm."""
519673
if adjacency is None:
520674
labels, n_labels = ndimage.label(x_in)
@@ -554,7 +708,51 @@ def _find_clusters_1dir(x, x_in, adjacency, max_step, t_power, ndimage):
554708
if sparse.issparse(adjacency) or adjacency is False:
555709
clusters = _get_components(x_in, adjacency)
556710
elif isinstance(adjacency, list): # use temporal adjacency
557-
clusters = _get_clusters_st(x_in, adjacency, max_step)
711+
if has_numba:
712+
# Numba union-find instead of Python BFS
713+
if _csr_data is not None:
714+
_indptr, _indices, _n_src = _csr_data
715+
else:
716+
_n_src = len(adjacency)
717+
_lengths = np.array([len(a) for a in adjacency])
718+
_indptr = np.zeros(_n_src + 1, dtype=np.intp)
719+
np.cumsum(_lengths, out=_indptr[1:])
720+
_indices = np.concatenate(adjacency).astype(np.intp)
721+
active_idx = np.where(x_in)[0].astype(np.intp)
722+
n_active = len(active_idx)
723+
if n_active == 0:
724+
if _sums_only:
725+
return [], np.atleast_1d(np.array([]))
726+
clusters = []
727+
else:
728+
_flat_map = -np.ones(len(x_in), dtype=np.intp)
729+
components = _st_fused_ccl(
730+
active_idx,
731+
n_active,
732+
_flat_map,
733+
_indptr,
734+
_indices,
735+
_n_src,
736+
max_step,
737+
)
738+
if _sums_only:
739+
if t_power == 1:
740+
sums = np.bincount(components, weights=x[active_idx])
741+
else:
742+
vals = (
743+
np.sign(x[active_idx])
744+
* np.abs(x[active_idx]) ** t_power
745+
)
746+
sums = np.bincount(components, weights=vals)
747+
return [], np.atleast_1d(sums)
748+
# Reconstruct cluster index arrays from component labels
749+
order = np.argsort(components, kind="stable")
750+
counts = np.bincount(components)
751+
splits = np.cumsum(counts[:-1])
752+
global_order = active_idx[order]
753+
clusters = list(np.split(global_order, splits))
754+
else:
755+
clusters = _get_clusters_st(x_in, adjacency, max_step)
558756
else:
559757
raise TypeError(
560758
f"adjacency must be a sparse array or list, got {type(adjacency)}"
@@ -637,6 +835,7 @@ def _setup_adjacency(adjacency, n_tests, n_times):
637835
)
638836
if adjacency.shape[0] == n_tests: # use global algorithm
639837
adjacency = adjacency.tocoo()
838+
return adjacency, None
640839
else: # use temporal adjacency algorithm
641840
got_times, mod = divmod(n_tests, adjacency.shape[0])
642841
if got_times != n_times or mod != 0:
@@ -648,12 +847,19 @@ def _setup_adjacency(adjacency, n_tests, n_times):
648847
"vertices can be excluded during forward computation"
649848
)
650849
# we claim to only use upper triangular part... not true here
651-
adjacency = (adjacency + adjacency.transpose()).tocsr()
850+
adjacency_csr = (adjacency + adjacency.transpose()).tocsr()
851+
n_src = adjacency_csr.shape[0]
852+
# Pre-compute CSR arrays to avoid redundant rebuilds in inner loops.
853+
csr_data = (
854+
adjacency_csr.indptr.astype(np.intp),
855+
adjacency_csr.indices.astype(np.intp),
856+
n_src,
857+
)
652858
adjacency = [
653-
adjacency.indices[adjacency.indptr[i] : adjacency.indptr[i + 1]]
654-
for i in range(len(adjacency.indptr) - 1)
859+
adjacency_csr.indices[adjacency_csr.indptr[i] : adjacency_csr.indptr[i + 1]]
860+
for i in range(n_src)
655861
]
656-
return adjacency
862+
return adjacency, csr_data
657863

658864

659865
def _do_permutations(
@@ -671,6 +877,7 @@ def _do_permutations(
671877
sample_shape,
672878
buffer_size,
673879
progress_bar,
880+
_csr_data=None,
674881
):
675882
n_samp, n_vars = X_full.shape
676883

@@ -725,6 +932,8 @@ def _do_permutations(
725932
partitions=partitions,
726933
include=include,
727934
t_power=t_power,
935+
_sums_only=True,
936+
_csr_data=_csr_data,
728937
)
729938
perm_clusters_sums = out[1]
730939

@@ -753,6 +962,7 @@ def _do_1samp_permutations(
753962
sample_shape,
754963
buffer_size,
755964
progress_bar,
965+
_csr_data=None,
756966
):
757967
n_samp, n_vars = X.shape
758968
assert slices is None # should be None for the 1 sample case
@@ -831,6 +1041,8 @@ def _do_1samp_permutations(
8311041
partitions=partitions,
8321042
include=include,
8331043
t_power=t_power,
1044+
_sums_only=True,
1045+
_csr_data=_csr_data,
8341046
)
8351047
perm_clusters_sums = out[1]
8361048
if len(perm_clusters_sums) > 0:
@@ -975,8 +1187,9 @@ def _permutation_cluster_test(
9751187
X = [np.reshape(x, (x.shape[0], -1)) for x in X]
9761188
n_tests = X[0].shape[1]
9771189

1190+
_adj_csr_data = None # Pre-computed CSR arrays for temporal adjacency
9781191
if adjacency is not None and adjacency is not False:
979-
adjacency = _setup_adjacency(adjacency, n_tests, n_times)
1192+
adjacency, _adj_csr_data = _setup_adjacency(adjacency, n_tests, n_times)
9801193

9811194
if (exclude is not None) and not exclude.size == n_tests:
9821195
raise ValueError("exclude must be the same shape as X[0]")
@@ -1035,6 +1248,7 @@ def _permutation_cluster_test(
10351248
partitions=partitions,
10361249
t_power=t_power,
10371250
show_info=True,
1251+
_csr_data=_adj_csr_data,
10381252
)
10391253
clusters, cluster_stats = out
10401254

@@ -1128,6 +1342,7 @@ def _permutation_cluster_test(
11281342
sample_shape,
11291343
buffer_size,
11301344
progress_bar.subset(idx),
1345+
_adj_csr_data,
11311346
)
11321347
for idx, order in split_list(orders, n_jobs, idx=True)
11331348
)

0 commit comments

Comments
 (0)