From ad8f2bc328535646d51483c9e553b1313d1d5a64 Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 9 Dec 2025 10:32:22 +0100 Subject: [PATCH 01/14] first commit to create branch --- src/spikeinterface/core/sorting_tools.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..15337671f5 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -185,8 +185,13 @@ def random_spikes_selection( random_spikes_indices: np.array Selected spike indices coresponding to the sorting spike vector. """ + rng_methods = ("uniform", "percentage") - if method == "uniform": + if method == "all": + spikes = sorting.to_spike_vector() + random_spikes_indices = np.arange(spikes.size) + + elif method in rng_methods: rng = np.random.default_rng(seed=seed) spikes = sorting.to_spike_vector(concatenated=False) @@ -211,17 +216,20 @@ def random_spikes_selection( inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index] all_unit_indices.append(inds_in_seg_abs) all_unit_indices = np.concatenate(all_unit_indices) + + if method == "uniform": + rng_size = min(max_spikes_per_unit, all_unit_indices.size) + elif method == "percentage": + rng_size = min(max_spikes_per_unit, all_unit_indices.size * percentage) + selected_unit_indices = rng.choice( - all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False + all_unit_indices, size=rng_size, replace=False, shuffle=False ) random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) random_spikes_indices = np.sort(random_spikes_indices) - elif method == "all": - spikes = sorting.to_spike_vector() - random_spikes_indices = np.arange(spikes.size) else: raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") From e6c9c2b1d459b02d099bf08f2881812e7ed4a965 Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 9 Dec 2025 17:54:54 +0100 Subject: [PATCH 02/14] temporal bin, rate cap and percentage sampling --- src/spikeinterface/core/sorting_tools.py | 93 ++++++++++++++++--- .../waveforms/temporal_pca.py | 4 +- src/spikeinterface/widgets/utils.py | 8 +- 3 files changed, 90 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 15337671f5..c41d0102d3 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -3,8 +3,11 @@ import warnings import importlib.util +from typing import Literal + import numpy as np +from spikeinterface.widgets.utils import get_segment_durations from spikeinterface.core.base import BaseExtractor from spikeinterface.core.basesorting import BaseSorting from spikeinterface.core.numpyextractors import NumpySorting @@ -148,12 +151,16 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units): return vector_to_list_of_spiketrain_numba -# TODO later : implement other method like "maximum_rate", "by_percent", ... +# stratified sampling (isi / amplitude / pca distance ? ) def random_spikes_selection( sorting: BaseSorting, - num_samples: int | None = None, - method: str = "uniform", + num_samples: list[int] | None = None, + method: Literal["uniform", "all", "percentage", "maximum_rate", "temporal_bins"] = "uniform", max_spikes_per_unit: int = 500, + percentage: float | None = None, + maximum_rate: float | None = None, + bin_size_s: float | None = None, + k_per_bin: int | None = None, margin_size: int | None = None, seed: int | None = None, ): @@ -167,14 +174,22 @@ def random_spikes_selection( ---------- sorting: BaseSorting The sorting object - num_samples: list of int + num_samples: list[int] | None, default: None The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - method: "uniform" | "all", default: "uniform" - The method to use. Only "uniform" is implemented for now + method: "uniform" | "percentage" | "maximum_rate" | "all", default: "uniform" + The method to use. max_spikes_per_unit: int, default: 500 - The number of spikes per units + The maximum number of spikes per units + percentage: float | None, default: None + In case of `percentage` method. The proportion of spikes per units. + maximum_rate: float | None, default: None + In case of `maximum_rate` method. The cap rate per units. + bin_size_s: float | None, default: None + In case of `temporal_bins` method. The duration of a temporal bin. + k_per_bin: int | None, default: None + In case of `temporal_bins` method. Maximum number of spikes per bins. margin_size: None | int, default: None A margin on each border of segments to avoid border spikes seed: None | int, default: None @@ -185,7 +200,7 @@ def random_spikes_selection( random_spikes_indices: np.array Selected spike indices coresponding to the sorting spike vector. """ - rng_methods = ("uniform", "percentage") + rng_methods = ("uniform", "percentage", "maximum_rate", "temporal_bins") if method == "all": spikes = sorting.to_spike_vector() @@ -194,6 +209,8 @@ def random_spikes_selection( elif method in rng_methods: rng = np.random.default_rng(seed=seed) + # since un concatenated + # spikes = [ [ (sample_index, unit_index, segment_index), (), ... ], [ (), ... ]] spikes = sorting.to_spike_vector(concatenated=False) cum_sizes = np.cumsum([0] + [s.size for s in spikes]) @@ -203,10 +220,15 @@ def random_spikes_selection( random_spikes_indices = [] for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] + all_unit_trains = [] for segment_index in range(sorting.get_num_segments()): - # this is local index + # this is local segment index + trains_in_seg = spike_trains[segment_index][unit_id] inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: + if num_samples is None: + raise ValueError("num_samples must be provided when margin_size is used") + local_spikes = spikes[segment_index][inds_in_seg] mask = (local_spikes["sample_index"] >= margin_size) & ( local_spikes["sample_index"] < (num_samples[segment_index] - margin_size) @@ -219,8 +241,56 @@ def random_spikes_selection( if method == "uniform": rng_size = min(max_spikes_per_unit, all_unit_indices.size) + elif method == "percentage": - rng_size = min(max_spikes_per_unit, all_unit_indices.size * percentage) + if percentage is None or not (0 < percentage <= 1) : + raise ValueError(f"percentage must be in the interval (0, 1]") + + rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) + + elif method == "maximum_rate": + if maximum_rate is None: + raise ValueError(f"maximum_rate must be defined") + + t_duration = np.sum(get_segment_durations(sorting)) + rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) + + elif method == "temporal_bins": + # expressed bin sampling as a dual sub sorting problem to be fully vectorized + + if None in (k_per_bin, bin_size_s): + missing = [] + if k_per_bin is None: + missing.append("k_per_bin") + if bin_size_s is None: + missing.append("bin_size_s") + print(f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}") + + sampling_frequency = sorting.get_sampling_frequency() + bin_size_freq = int(bin_size_s * sampling_frequency) + + unit_spikes = np.concat(spikes)[all_unit_indices] + + # local to segment so will loop and reset + bin_index = unit_spikes["sample_index"] // bin_size_freq + segment_index = unit_spikes["segment_index"] + + group_values = np.stack((segment_index, bin_index), axis = 1) + _, group_keys = np.unique(group_values, return_inverse = True, axis= 0) + + score = rng.random(all_unit_indices.size) + order = np.lexsort((score, group_keys)) + + ordered_unit_indices = all_unit_indices[order] + + group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1 ] + counts = np.diff(np.r_[group_start, ordered_unit_indices.size]) + + ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) + selection_mask = ranks <= k_per_bin + selected = ordered_unit_indices[selection_mask] + random_spikes_indices.append(selected) + continue selected_unit_indices = rng.choice( all_unit_indices, size=rng_size, replace=False, shuffle=False @@ -231,11 +301,10 @@ def random_spikes_selection( random_spikes_indices = np.sort(random_spikes_indices) else: - raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") + raise ValueError(f"random_spikes_selection(): method must be 'all' or any in {', '.join(rng_methods)}") return random_spikes_indices - ### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index b1d3d5deaf..4720ff9098 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -96,11 +96,11 @@ def fit( model_folder_path: str, detect_peaks_params: dict, peak_selection_params: dict, - job_kwargs: dict = None, + job_kwargs: dict | None = None, ms_before: float = 1.0, ms_after: float = 1.0, whiten: bool = True, - radius_um: float = None, + radius_um: float | None = None, ) -> "IncrementalPCA": """ Train a pca model using the data in the recording object and the parameters provided. diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 50406b109e..923a950979 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -401,7 +401,7 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor return segment_indices -def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]: +def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] = None) -> list[float]: """ Calculate the duration of each segment in a sorting object. @@ -410,11 +410,17 @@ def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> l sorting : BaseSorting The sorting object containing spike data + segment_indices : list[int] | None + List of the segment indices to process. Default to None. + Returns ------- list[float] List of segment durations in seconds """ + if segment_indices is None: + segment_indices = range(sorting.get_num_segments()) + spikes = sorting.to_spike_vector() segment_boundaries = [ From bef42f1deea6c9e43c3a36e7e26c86ce283fed71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:57:44 +0000 Subject: [PATCH 03/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index c41d0102d3..f17ae9a600 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -174,7 +174,7 @@ def random_spikes_selection( ---------- sorting: BaseSorting The sorting object - num_samples: list[int] | None, default: None + num_samples: list[int] | None, default: None The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] @@ -243,7 +243,7 @@ def random_spikes_selection( rng_size = min(max_spikes_per_unit, all_unit_indices.size) elif method == "percentage": - if percentage is None or not (0 < percentage <= 1) : + if percentage is None or not (0 < percentage <= 1): raise ValueError(f"percentage must be in the interval (0, 1]") rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) @@ -264,7 +264,9 @@ def random_spikes_selection( missing.append("k_per_bin") if bin_size_s is None: missing.append("bin_size_s") - print(f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}") + print( + f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}" + ) sampling_frequency = sorting.get_sampling_frequency() bin_size_freq = int(bin_size_s * sampling_frequency) @@ -275,15 +277,15 @@ def random_spikes_selection( bin_index = unit_spikes["sample_index"] // bin_size_freq segment_index = unit_spikes["segment_index"] - group_values = np.stack((segment_index, bin_index), axis = 1) - _, group_keys = np.unique(group_values, return_inverse = True, axis= 0) + group_values = np.stack((segment_index, bin_index), axis=1) + _, group_keys = np.unique(group_values, return_inverse=True, axis=0) score = rng.random(all_unit_indices.size) order = np.lexsort((score, group_keys)) ordered_unit_indices = all_unit_indices[order] - - group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1 ] + + group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1] counts = np.diff(np.r_[group_start, ordered_unit_indices.size]) ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) @@ -292,9 +294,7 @@ def random_spikes_selection( random_spikes_indices.append(selected) continue - selected_unit_indices = rng.choice( - all_unit_indices, size=rng_size, replace=False, shuffle=False - ) + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) @@ -305,6 +305,7 @@ def random_spikes_selection( return random_spikes_indices + ### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, From b037441dc1819c93e6e57f853427c041d4a2fc54 Mon Sep 17 00:00:00 2001 From: tayheau Date: Thu, 18 Dec 2025 15:24:12 +0100 Subject: [PATCH 04/14] lazy loading get_segment_duration --- src/spikeinterface/core/sorting_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index f17ae9a600..4ab38bc2a9 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -7,7 +7,6 @@ import numpy as np -from spikeinterface.widgets.utils import get_segment_durations from spikeinterface.core.base import BaseExtractor from spikeinterface.core.basesorting import BaseSorting from spikeinterface.core.numpyextractors import NumpySorting @@ -207,6 +206,7 @@ def random_spikes_selection( random_spikes_indices = np.arange(spikes.size) elif method in rng_methods: + from spikeinterface.widgets.utils import get_segment_durations rng = np.random.default_rng(seed=seed) # since un concatenated From 625bbef6401bcb69e3124225564024757190143b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:24:41 +0000 Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 4ab38bc2a9..0412d00a3a 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -207,6 +207,7 @@ def random_spikes_selection( elif method in rng_methods: from spikeinterface.widgets.utils import get_segment_durations + rng = np.random.default_rng(seed=seed) # since un concatenated From acfcb660b21294d74af7859b8b5f0e0049ec574e Mon Sep 17 00:00:00 2001 From: tayheau Date: Thu, 18 Dec 2025 18:24:28 +0100 Subject: [PATCH 06/14] removed unused var --- src/spikeinterface/core/sorting_tools.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 0412d00a3a..6e84d31013 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -221,10 +221,8 @@ def random_spikes_selection( random_spikes_indices = [] for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] - all_unit_trains = [] for segment_index in range(sorting.get_num_segments()): # this is local segment index - trains_in_seg = spike_trains[segment_index][unit_id] inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: if num_samples is None: @@ -265,7 +263,7 @@ def random_spikes_selection( missing.append("k_per_bin") if bin_size_s is None: missing.append("bin_size_s") - print( + raise ValueError( f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}" ) From 2f00b1b6eccf9cdc9ddae1ba335be93229641d7b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 23 Dec 2025 14:48:43 +0100 Subject: [PATCH 07/14] small changes --- src/spikeinterface/core/sorting_tools.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 6e84d31013..e0184ff296 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -177,7 +177,7 @@ def random_spikes_selection( The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - method: "uniform" | "percentage" | "maximum_rate" | "all", default: "uniform" + method: "uniform" | "percentage" | "maximum_rate" | "all" | "temporal_bins", default: "uniform" The method to use. max_spikes_per_unit: int, default: 500 The maximum number of spikes per units @@ -215,7 +215,7 @@ def random_spikes_selection( spikes = sorting.to_spike_vector(concatenated=False) cum_sizes = np.cumsum([0] + [s.size for s in spikes]) - # this fast when numba + # this is fast when numba is installed spike_indices = spike_vector_to_indices(spikes, sorting.unit_ids, absolute_index=False) random_spikes_indices = [] @@ -240,12 +240,14 @@ def random_spikes_selection( if method == "uniform": rng_size = min(max_spikes_per_unit, all_unit_indices.size) + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) elif method == "percentage": if percentage is None or not (0 < percentage <= 1): raise ValueError(f"percentage must be in the interval (0, 1]") rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) elif method == "maximum_rate": if maximum_rate is None: @@ -253,6 +255,7 @@ def random_spikes_selection( t_duration = np.sum(get_segment_durations(sorting)) rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) elif method == "temporal_bins": # expressed bin sampling as a dual sub sorting problem to be fully vectorized @@ -289,11 +292,8 @@ def random_spikes_selection( ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) selection_mask = ranks <= k_per_bin - selected = ordered_unit_indices[selection_mask] - random_spikes_indices.append(selected) - continue - - selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) + selected_unit_indices = ordered_unit_indices[selection_mask] + random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) From 5aebf6678210e08e3d5d90aa11f50f27c4c50bf3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 13:49:15 +0000 Subject: [PATCH 08/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index e0184ff296..f02c15c2fa 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -293,7 +293,7 @@ def random_spikes_selection( ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) selection_mask = ranks <= k_per_bin selected_unit_indices = ordered_unit_indices[selection_mask] - + random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) From fa3bf6f0dfbc815142c16eaae11c407390cdcdaa Mon Sep 17 00:00:00 2001 From: tayheau Date: Fri, 6 Mar 2026 16:36:51 +0100 Subject: [PATCH 09/14] removed rnd selection method --- src/spikeinterface/core/sorting_tools.py | 48 ++---------------------- 1 file changed, 3 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index a6a64f311b..17fc830f79 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -154,12 +154,10 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units): def random_spikes_selection( sorting: BaseSorting, num_samples: list[int] | None = None, - method: Literal["uniform", "all", "percentage", "maximum_rate", "temporal_bins"] = "uniform", + method: Literal["uniform", "all", "percentage", "maximum_rate"] = "uniform", max_spikes_per_unit: int = 500, percentage: float | None = None, maximum_rate: float | None = None, - bin_size_s: float | None = None, - k_per_bin: int | None = None, margin_size: int | None = None, seed: int | None = None, ): @@ -177,7 +175,7 @@ def random_spikes_selection( The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - method: "uniform" | "percentage" | "maximum_rate" | "all" | "temporal_bins", default: "uniform" + method: "uniform" | "percentage" | "maximum_rate" | "all" , default: "uniform" The method to use. max_spikes_per_unit: int, default: 500 The maximum number of spikes per units @@ -185,10 +183,6 @@ def random_spikes_selection( In case of `percentage` method. The proportion of spikes per units. maximum_rate: float | None, default: None In case of `maximum_rate` method. The cap rate per units. - bin_size_s: float | None, default: None - In case of `temporal_bins` method. The duration of a temporal bin. - k_per_bin: int | None, default: None - In case of `temporal_bins` method. Maximum number of spikes per bins. margin_size: None | int, default: None A margin on each border of segments to avoid border spikes seed: None | int, default: None @@ -199,7 +193,7 @@ def random_spikes_selection( random_spikes_indices: np.array Selected spike indices coresponding to the sorting spike vector. """ - rng_methods = ("uniform", "percentage", "maximum_rate", "temporal_bins") + rng_methods = ("uniform", "percentage", "maximum_rate") if method == "all": spikes = sorting.to_spike_vector() @@ -257,42 +251,6 @@ def random_spikes_selection( rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) - elif method == "temporal_bins": - # expressed bin sampling as a dual sub sorting problem to be fully vectorized - - if None in (k_per_bin, bin_size_s): - missing = [] - if k_per_bin is None: - missing.append("k_per_bin") - if bin_size_s is None: - missing.append("bin_size_s") - raise ValueError( - f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}" - ) - - sampling_frequency = sorting.get_sampling_frequency() - bin_size_freq = int(bin_size_s * sampling_frequency) - - unit_spikes = np.concat(spikes)[all_unit_indices] - - # local to segment so will loop and reset - bin_index = unit_spikes["sample_index"] // bin_size_freq - segment_index = unit_spikes["segment_index"] - - group_values = np.stack((segment_index, bin_index), axis=1) - _, group_keys = np.unique(group_values, return_inverse=True, axis=0) - - score = rng.random(all_unit_indices.size) - order = np.lexsort((score, group_keys)) - - ordered_unit_indices = all_unit_indices[order] - - group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1] - counts = np.diff(np.r_[group_start, ordered_unit_indices.size]) - - ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) - selection_mask = ranks <= k_per_bin - selected_unit_indices = ordered_unit_indices[selection_mask] random_spikes_indices.append(selected_unit_indices) From d49f9d456f06c9cae7330b48cc7a0313cb44a8e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:38:45 +0000 Subject: [PATCH 10/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 17fc830f79..ebec4a649d 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -251,7 +251,6 @@ def random_spikes_selection( rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) - random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) From d373856b141ce6c3288b2ee2763f004ddfb76d6b Mon Sep 17 00:00:00 2001 From: tayheau Date: Fri, 6 Mar 2026 16:46:56 +0100 Subject: [PATCH 11/14] propagated args to `ComputeRandomSpikes` --- src/spikeinterface/core/analyzer_extension_core.py | 8 ++++++-- src/spikeinterface/core/sorting_tools.py | 12 ++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 43f18e8d63..655edb89ec 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -35,7 +35,7 @@ class ComputeRandomSpikes(AnalyzerExtension): Parameters ---------- - method : "uniform" | "all", default: "uniform" + method: "uniform" | "percentage" | "maximum_rate" | "all" , default: "uniform" The method to select the spikes max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" @@ -43,6 +43,10 @@ class ComputeRandomSpikes(AnalyzerExtension): A margin on each border of segments to avoid border spikes, ignored if method="all" seed : int or None, default: None A seed for the random generator, ignored if method="all" + percentage: float | None, default: None + In case of `percentage` method. The proportion of spikes per units. + maximum_rate: float | None, default: None + In case of `maximum_rate` method. The cap rate per units. Returns ------- @@ -64,7 +68,7 @@ def _run(self, verbose=False): **self.params, ) - def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None): + def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None, percentage=None, maximum_rate=None): params = dict(method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed) return params diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index ebec4a649d..3ec0c37941 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -156,10 +156,10 @@ def random_spikes_selection( num_samples: list[int] | None = None, method: Literal["uniform", "all", "percentage", "maximum_rate"] = "uniform", max_spikes_per_unit: int = 500, - percentage: float | None = None, - maximum_rate: float | None = None, margin_size: int | None = None, seed: int | None = None, + percentage: float | None = None, + maximum_rate: float | None = None, ): """ This replaces `select_random_spikes_uniformly()`. @@ -179,14 +179,14 @@ def random_spikes_selection( The method to use. max_spikes_per_unit: int, default: 500 The maximum number of spikes per units - percentage: float | None, default: None - In case of `percentage` method. The proportion of spikes per units. - maximum_rate: float | None, default: None - In case of `maximum_rate` method. The cap rate per units. margin_size: None | int, default: None A margin on each border of segments to avoid border spikes seed: None | int, default: None A seed for random generator + percentage: float | None, default: None + In case of `percentage` method. The proportion of spikes per units. + maximum_rate: float | None, default: None + In case of `maximum_rate` method. The cap rate per units. Returns ------- From f1315d6cc1aefa60a9cfe486b67e2a76447db17a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 15:55:23 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/analyzer_extension_core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 655edb89ec..a03239d954 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -68,7 +68,9 @@ def _run(self, verbose=False): **self.params, ) - def _set_params(self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None, percentage=None, maximum_rate=None): + def _set_params( + self, method="uniform", max_spikes_per_unit=500, margin_size=None, seed=None, percentage=None, maximum_rate=None + ): params = dict(method=method, max_spikes_per_unit=max_spikes_per_unit, margin_size=margin_size, seed=seed) return params From f000f2d84cb92ef9e7394d71650389cb11fea0ae Mon Sep 17 00:00:00 2001 From: tayheau Date: Fri, 6 Mar 2026 17:08:27 +0100 Subject: [PATCH 13/14] better description --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- src/spikeinterface/core/sorting_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index a03239d954..343082b62f 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -36,7 +36,7 @@ class ComputeRandomSpikes(AnalyzerExtension): Parameters ---------- method: "uniform" | "percentage" | "maximum_rate" | "all" , default: "uniform" - The method to select the spikes + Method to select spikes: "uniform" randomly up to max_spikes_per_unit, "percentage" selects a fraction of spikes, and "maximum_rate" limits selection by spike rate over time. max_spikes_per_unit : int, default: 500 The maximum number of spikes per unit, ignored if method="all" margin_size : int, default: None diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 3ec0c37941..4fff7494d8 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -176,7 +176,7 @@ def random_spikes_selection( Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] method: "uniform" | "percentage" | "maximum_rate" | "all" , default: "uniform" - The method to use. + Method to select spikes: "uniform" randomly up to max_spikes_per_unit, "percentage" selects a fraction of spikes, and "maximum_rate" limits selection by spike rate over time. max_spikes_per_unit: int, default: 500 The maximum number of spikes per units margin_size: None | int, default: None From 8e34274e8e25083aafd31f434d53ea21bbd25d69 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 9 Mar 2026 16:32:18 +0100 Subject: [PATCH 14/14] Apply suggestion from @chrishalcrow Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/widgets/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 000193bb43..1971a32ff8 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -407,7 +407,6 @@ def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] = Non ---------- sorting : BaseSorting The sorting object containing spike data - segment_indices : list[int] | None List of the segment indices to process. Default to None.