@@ -114,6 +114,117 @@ def _sum_cluster_data(data, tstep):
114114 return np .sign (data ) * np .logical_not (data == 0 ) * tstep
115115
116116
@jit()
def _st_fused_ccl(
    active_idx, n_active, flat_to_active, adj_indptr, adj_indices, n_src, max_step
):
    """Spatio-temporal union-find for neighbor-list adjacency.

    Single-pass compiled union-find over spatial neighbors (from CSR
    adjacency) and temporal neighbors (same vertex at adjacent time steps).

    Parameters
    ----------
    active_idx : ndarray of intp
        Flat indices of active (supra-threshold) vertices.
    n_active : int
        Number of active vertices (len(active_idx)).
    flat_to_active : ndarray of intp, shape (n_total,)
        Pre-allocated lookup buffer, initialized to -1. Restored to all -1
        before returning so the caller can reuse it across calls.
    adj_indptr : ndarray of intp, shape (n_src + 1,)
        CSR indptr for spatial adjacency.
    adj_indices : ndarray of intp
        CSR indices for spatial adjacency.
    n_src : int
        Number of spatial vertices.
    max_step : int
        Maximum temporal step for adjacency.

    Returns
    -------
    components : ndarray of intp
        Component labels (0..n_components-1) for each active vertex.
    """
    # Phase 1: Build flat→active mapping (touches O(n_active) entries only,
    # so the large buffer never needs a full reset)
    for i in range(n_active):
        flat_to_active[active_idx[i]] = i

    # Phase 2: Union-find over spatial + temporal edges. Spatial and
    # temporal neighbors are enumerated in one loop so the find/union
    # logic exists exactly once (kept inline: this function is compiled
    # with numba, so we avoid extra call overhead / helper dispatch).
    parent = np.arange(n_active)
    rank = np.zeros(n_active, dtype=np.int32)

    for a_pos in range(n_active):
        flat_i = active_idx[a_pos]
        t_i = flat_i // n_src
        s_i = flat_i - t_i * n_src

        # Neighbor k enumerates: first the spatial neighbors at the same
        # time point, then the same spatial vertex at 1..max_step earlier
        # time steps (same order as the original two-branch version).
        start = adj_indptr[s_i]
        n_spatial = adj_indptr[s_i + 1] - start
        for k in range(n_spatial + max_step):
            if k < n_spatial:
                flat_j = t_i * n_src + adj_indices[start + k]
            else:
                step = k - n_spatial + 1
                if t_i < step:
                    continue  # would index before t=0
                flat_j = (t_i - step) * n_src + s_i
            b_pos = flat_to_active[flat_j]
            if b_pos < 0:
                continue  # neighbor is not active
            # Find both roots with path halving
            ra = a_pos
            while parent[ra] != ra:
                parent[ra] = parent[parent[ra]]
                ra = parent[ra]
            rb = b_pos
            while parent[rb] != rb:
                parent[rb] = parent[parent[rb]]
                rb = parent[rb]
            # Union by rank (equal ranks attach rb under ra and bump ra)
            if ra != rb:
                if rank[ra] < rank[rb]:
                    parent[ra] = rb
                else:
                    parent[rb] = ra
                    if rank[ra] == rank[rb]:
                        rank[ra] += 1

    # Phase 3: Final path compression + relabel to 0..n_components-1 in
    # order of first appearance (stable regardless of which root each
    # component ended up with)
    label_map = -np.ones(n_active, dtype=np.intp)
    next_label = np.intp(0)
    components = np.empty(n_active, dtype=np.intp)
    for i in range(n_active):
        a = i
        while parent[a] != a:
            a = parent[a]
        parent[i] = a
        if label_map[a] == -1:
            label_map[a] = next_label
            next_label += 1
        components[i] = label_map[a]

    # Phase 4: Clean up flat_to_active for next call (O(n_active) only)
    for i in range(n_active):
        flat_to_active[active_idx[i]] = -1

    return components
226+
227+
117228def _get_clusters_spatial (s , neighbors ):
118229 """Form spatial clusters using neighbor lists.
119230
@@ -330,6 +441,8 @@ def _find_clusters(
330441 partitions = None ,
331442 t_power = 1 ,
332443 show_info = False ,
444+ _sums_only = False ,
445+ _csr_data = None ,
333446):
334447 """Find all clusters which are above/below a certain threshold.
335448
@@ -461,7 +574,15 @@ def _find_clusters(
461574 for x_in in x_ins :
462575 if np .any (x_in ):
463576 out = _find_clusters_1dir_parts (
464- x , x_in , adjacency , max_step , partitions , t_power , ndimage
577+ x ,
578+ x_in ,
579+ adjacency ,
580+ max_step ,
581+ partitions ,
582+ t_power ,
583+ ndimage ,
584+ _sums_only = _sums_only and not tfce ,
585+ _csr_data = _csr_data ,
465586 )
466587 clusters += out [0 ]
467588 sums .append (out [1 ])
@@ -494,27 +615,60 @@ def _find_clusters(
494615
495616
def _find_clusters_1dir_parts(
    x,
    x_in,
    adjacency,
    max_step,
    partitions,
    t_power,
    ndimage,
    _sums_only=False,
    _csr_data=None,
):
    """Deal with partitions, and pass the work to _find_clusters_1dir."""
    if partitions is None:
        # No partitioning: one clustering pass over the full mask
        return _find_clusters_1dir(
            x,
            x_in,
            adjacency,
            max_step,
            t_power,
            ndimage,
            _sums_only,
            _csr_data=_csr_data,
        )
    # cluster each partition separately, then pool the results
    clusters = list()
    partial_sums = list()
    n_parts = np.max(partitions) + 1
    for part in range(n_parts):
        mask = np.logical_and(x_in, partitions == part)
        part_clusters, part_sums = _find_clusters_1dir(
            x,
            mask,
            adjacency,
            max_step,
            t_power,
            ndimage,
            _sums_only,
            _csr_data=_csr_data,
        )
        clusters.extend(part_clusters)
        partial_sums.append(part_sums)
    return clusters, np.concatenate(partial_sums)
515660
516661
517- def _find_clusters_1dir (x , x_in , adjacency , max_step , t_power , ndimage ):
662+ def _find_clusters_1dir (
663+ x ,
664+ x_in ,
665+ adjacency ,
666+ max_step ,
667+ t_power ,
668+ ndimage ,
669+ _sums_only = False ,
670+ _csr_data = None ,
671+ ):
518672 """Actually call the clustering algorithm."""
519673 if adjacency is None :
520674 labels , n_labels = ndimage .label (x_in )
@@ -554,7 +708,51 @@ def _find_clusters_1dir(x, x_in, adjacency, max_step, t_power, ndimage):
554708 if sparse .issparse (adjacency ) or adjacency is False :
555709 clusters = _get_components (x_in , adjacency )
556710 elif isinstance (adjacency , list ): # use temporal adjacency
557- clusters = _get_clusters_st (x_in , adjacency , max_step )
711+ if has_numba :
712+ # Numba union-find instead of Python BFS
713+ if _csr_data is not None :
714+ _indptr , _indices , _n_src = _csr_data
715+ else :
716+ _n_src = len (adjacency )
717+ _lengths = np .array ([len (a ) for a in adjacency ])
718+ _indptr = np .zeros (_n_src + 1 , dtype = np .intp )
719+ np .cumsum (_lengths , out = _indptr [1 :])
720+ _indices = np .concatenate (adjacency ).astype (np .intp )
721+ active_idx = np .where (x_in )[0 ].astype (np .intp )
722+ n_active = len (active_idx )
723+ if n_active == 0 :
724+ if _sums_only :
725+ return [], np .atleast_1d (np .array ([]))
726+ clusters = []
727+ else :
728+ _flat_map = - np .ones (len (x_in ), dtype = np .intp )
729+ components = _st_fused_ccl (
730+ active_idx ,
731+ n_active ,
732+ _flat_map ,
733+ _indptr ,
734+ _indices ,
735+ _n_src ,
736+ max_step ,
737+ )
738+ if _sums_only :
739+ if t_power == 1 :
740+ sums = np .bincount (components , weights = x [active_idx ])
741+ else :
742+ vals = (
743+ np .sign (x [active_idx ])
744+ * np .abs (x [active_idx ]) ** t_power
745+ )
746+ sums = np .bincount (components , weights = vals )
747+ return [], np .atleast_1d (sums )
748+ # Reconstruct cluster index arrays from component labels
749+ order = np .argsort (components , kind = "stable" )
750+ counts = np .bincount (components )
751+ splits = np .cumsum (counts [:- 1 ])
752+ global_order = active_idx [order ]
753+ clusters = list (np .split (global_order , splits ))
754+ else :
755+ clusters = _get_clusters_st (x_in , adjacency , max_step )
558756 else :
559757 raise TypeError (
560758 f"adjacency must be a sparse array or list, got { type (adjacency )} "
@@ -637,6 +835,7 @@ def _setup_adjacency(adjacency, n_tests, n_times):
637835 )
638836 if adjacency .shape [0 ] == n_tests : # use global algorithm
639837 adjacency = adjacency .tocoo ()
838+ return adjacency , None
640839 else : # use temporal adjacency algorithm
641840 got_times , mod = divmod (n_tests , adjacency .shape [0 ])
642841 if got_times != n_times or mod != 0 :
@@ -648,12 +847,19 @@ def _setup_adjacency(adjacency, n_tests, n_times):
648847 "vertices can be excluded during forward computation"
649848 )
650849 # we claim to only use upper triangular part... not true here
651- adjacency = (adjacency + adjacency .transpose ()).tocsr ()
850+ adjacency_csr = (adjacency + adjacency .transpose ()).tocsr ()
851+ n_src = adjacency_csr .shape [0 ]
852+ # Pre-compute CSR arrays to avoid redundant rebuilds in inner loops.
853+ csr_data = (
854+ adjacency_csr .indptr .astype (np .intp ),
855+ adjacency_csr .indices .astype (np .intp ),
856+ n_src ,
857+ )
652858 adjacency = [
653- adjacency .indices [adjacency .indptr [i ] : adjacency .indptr [i + 1 ]]
654- for i in range (len ( adjacency . indptr ) - 1 )
859+ adjacency_csr .indices [adjacency_csr .indptr [i ] : adjacency_csr .indptr [i + 1 ]]
860+ for i in range (n_src )
655861 ]
656- return adjacency
862+ return adjacency , csr_data
657863
658864
659865def _do_permutations (
@@ -671,6 +877,7 @@ def _do_permutations(
671877 sample_shape ,
672878 buffer_size ,
673879 progress_bar ,
880+ _csr_data = None ,
674881):
675882 n_samp , n_vars = X_full .shape
676883
@@ -725,6 +932,8 @@ def _do_permutations(
725932 partitions = partitions ,
726933 include = include ,
727934 t_power = t_power ,
935+ _sums_only = True ,
936+ _csr_data = _csr_data ,
728937 )
729938 perm_clusters_sums = out [1 ]
730939
@@ -753,6 +962,7 @@ def _do_1samp_permutations(
753962 sample_shape ,
754963 buffer_size ,
755964 progress_bar ,
965+ _csr_data = None ,
756966):
757967 n_samp , n_vars = X .shape
758968 assert slices is None # should be None for the 1 sample case
@@ -831,6 +1041,8 @@ def _do_1samp_permutations(
8311041 partitions = partitions ,
8321042 include = include ,
8331043 t_power = t_power ,
1044+ _sums_only = True ,
1045+ _csr_data = _csr_data ,
8341046 )
8351047 perm_clusters_sums = out [1 ]
8361048 if len (perm_clusters_sums ) > 0 :
@@ -975,8 +1187,9 @@ def _permutation_cluster_test(
9751187 X = [np .reshape (x , (x .shape [0 ], - 1 )) for x in X ]
9761188 n_tests = X [0 ].shape [1 ]
9771189
1190+ _adj_csr_data = None # Pre-computed CSR arrays for temporal adjacency
9781191 if adjacency is not None and adjacency is not False :
979- adjacency = _setup_adjacency (adjacency , n_tests , n_times )
1192+ adjacency , _adj_csr_data = _setup_adjacency (adjacency , n_tests , n_times )
9801193
9811194 if (exclude is not None ) and not exclude .size == n_tests :
9821195 raise ValueError ("exclude must be the same shape as X[0]" )
@@ -1035,6 +1248,7 @@ def _permutation_cluster_test(
10351248 partitions = partitions ,
10361249 t_power = t_power ,
10371250 show_info = True ,
1251+ _csr_data = _adj_csr_data ,
10381252 )
10391253 clusters , cluster_stats = out
10401254
@@ -1128,6 +1342,7 @@ def _permutation_cluster_test(
11281342 sample_shape ,
11291343 buffer_size ,
11301344 progress_bar .subset (idx ),
1345+ _adj_csr_data ,
11311346 )
11321347 for idx , order in split_list (orders , n_jobs , idx = True )
11331348 )
0 commit comments