Skip to content
80 changes: 57 additions & 23 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def create_sorting_analyzer(
return sorting_analyzer


def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer":
def load_sorting_analyzer(
folder, load_extensions=True, format="auto", backend_options=None, lazy=False
) -> "SortingAnalyzer":
"""
Load a SortingAnalyzer object from disk.

Expand Down Expand Up @@ -245,7 +247,9 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_o
The loaded SortingAnalyzer

"""
return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, backend_options=backend_options)
return SortingAnalyzer.load(
folder, load_extensions=load_extensions, format=format, backend_options=backend_options, lazy=lazy
)


class SortingAnalyzer:
Expand Down Expand Up @@ -279,6 +283,7 @@ def __init__(
sparsity: ChannelSparsity | None = None,
return_in_uV: bool = True,
backend_options: dict | None = None,
lazy: bool = False,
):
# very fast init because checks are done in load and create
self.sorting = sorting
Expand All @@ -304,6 +309,9 @@ def __init__(
# (additional saving options for creating and saving datasets, e.g. compression/filters for zarr)
self._backend_options = {} if backend_options is None else backend_options

# the lazy flag is used to load the extensions in a lazy way (only when needed)
self._lazy = lazy

# extensions are not loaded at init
self.extensions = dict()

Expand Down Expand Up @@ -407,7 +415,7 @@ def create(
return sorting_analyzer

@classmethod
def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None):
def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None, lazy=False):
"""
Load folder or zarr.
The recording can be given if the recording location has changed.
Expand All @@ -422,14 +430,14 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe

if format == "binary_folder":
sorting_analyzer = SortingAnalyzer.load_from_binary_folder(
folder, recording=recording, backend_options=backend_options
folder, recording=recording, backend_options=backend_options, lazy=lazy
)
elif format == "zarr":
sorting_analyzer = SortingAnalyzer.load_from_zarr(
folder, recording=recording, backend_options=backend_options
folder, recording=recording, backend_options=backend_options, lazy=lazy
)

if not is_path_remote(str(folder)):
if not is_path_remote(str(folder)) and not lazy:
if load_extensions:
sorting_analyzer.load_all_saved_extension()

Expand Down Expand Up @@ -532,15 +540,24 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV
return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options)

@classmethod
def load_from_binary_folder(cls, folder, recording=None, backend_options=None):
def load_from_binary_folder(cls, folder, recording=None, backend_options=None, lazy=False):
from .loading import load

folder = Path(folder)
assert folder.is_dir(), f"This folder does not exists {folder}"

# load internal sorting copy in memory
if lazy:
numpy_folder_kwargs = dict(mmap_mode="r")
copy_spike_vector = False
else:
numpy_folder_kwargs = dict()
copy_spike_vector = True

sorting = NumpySorting.from_sorting(
NumpyFolderSorting(folder / "sorting"), with_metadata=True, copy_spike_vector=True
NumpyFolderSorting(folder / "sorting", **numpy_folder_kwargs),
with_metadata=True,
copy_spike_vector=copy_spike_vector,
)

# Try to load the recording if not provided
Expand Down Expand Up @@ -601,6 +618,7 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None):
sparsity=sparsity,
return_in_uV=return_in_uV,
backend_options=backend_options,
lazy=lazy,
)
sorting_analyzer.folder = folder

Expand Down Expand Up @@ -698,7 +716,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options)

@classmethod
def load_from_zarr(cls, folder, recording=None, backend_options=None):
def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False):
import zarr
from .loading import load

Expand All @@ -721,11 +739,22 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):
"Please consider re-generating the SortingAnalyzer object."
)

# load internal sorting in memory
if lazy:
copy_spike_vector = False
lazy_spike_vector = True
else:
copy_spike_vector = True
lazy_spike_vector = False

sorting = NumpySorting.from_sorting(
ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
ZarrSortingExtractor(
folder,
zarr_group="sorting",
storage_options=storage_options,
lazy_spike_vector=lazy_spike_vector,
),
with_metadata=True,
copy_spike_vector=True,
copy_spike_vector=copy_spike_vector,
)

# load recording if possible
Expand Down Expand Up @@ -770,6 +799,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):
sparsity=sparsity,
return_in_uV=return_in_uV,
backend_options=backend_options,
lazy=lazy,
)
sorting_analyzer.folder = folder

Expand Down Expand Up @@ -988,6 +1018,11 @@ def _save_or_select_or_merge_or_split(
new_sorting_analyzer : SortingAnalyzer
The newly created SortingAnalyzer object.
"""
if self._lazy:
raise ValueError(
"Cannot save, select, merge or split units when the SortingAnalyzer is lazy. "
"Please load the SortingAnalyzer with lazy=False."
)
if self.has_recording():
recording = self._recording
elif self.has_temporary_recording():
Expand Down Expand Up @@ -1936,7 +1971,7 @@ def load_extension(self, extension_name: str):
if extension_class is None:
return None

extension_instance = extension_class.load(self)
extension_instance = extension_class.load(self, lazy=self._lazy)

self.extensions[extension_name] = extension_instance

Expand Down Expand Up @@ -2414,20 +2449,20 @@ def _get_zarr_extension_group(self, mode="r+"):
return extension_group

@classmethod
def load(cls, sorting_analyzer):
def load(cls, sorting_analyzer, lazy=False):
ext = cls(sorting_analyzer)
ext.load_params()
ext.load_run_info()
if ext.run_info is not None:
if ext.run_info["run_completed"]:
ext.load_data()
ext.load_data(lazy=lazy)
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
return ext
else:
# this is for back-compatibility of old analyzers
ext.load_data()
ext.load_data(lazy=lazy)
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
Expand Down Expand Up @@ -2527,7 +2562,7 @@ def load_params(self):

self.params = params

def load_data(self):
def load_data(self, lazy=False):
ext_data = None
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
Expand All @@ -2547,10 +2582,10 @@ def load_data(self):
ext_data = json.load(f)
elif ext_data_file.suffix == ".npy":
# The lazy loading of an extension is complicated because if we compute again
# and have a link to the old buffer on windows then it fails
# ext_data = np.load(ext_data_file, mmap_mode="r")
# so we go back to full loading
ext_data = np.load(ext_data_file)
# and have a link to the old buffer on windows then it fails.
# So, by default, we use full loading, but lazy can be requested on demand.
kwargs = dict(mmap_mode="r") if lazy else dict()
ext_data = np.load(ext_data_file, **kwargs)
elif ext_data_file.suffix == ".csv":
import pandas as pd

Expand Down Expand Up @@ -2586,8 +2621,7 @@ def load_data(self):
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
# this load in memmory
ext_data = np.array(ext_data_)
ext_data = ext_data_ if lazy else np.array(ext_data_[:])
self.set_data(ext_data_name, ext_data)

if len(self.data) == 0:
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/sortingfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class NumpyFolderSorting(BaseSorting):
mode = "folder"
name = "NumpyFolder"

def __init__(self, folder_path):
def __init__(self, folder_path, mmap_mode=None):
folder_path = Path(folder_path)

with open(folder_path / "numpysorting_info.json", "r") as f:
Expand All @@ -36,7 +36,7 @@ def __init__(self, folder_path):

BaseSorting.__init__(self, sampling_frequency, unit_ids)

self.spikes = np.load(folder_path / "spikes.npy")
self.spikes = np.load(folder_path / "spikes.npy", mmap_mode=mmap_mode)

for segment_index in range(num_segments):
self.add_sorting_segment(SpikeVectorSortingSegment(self.spikes, segment_index, unit_ids))
Expand All @@ -47,7 +47,7 @@ def __init__(self, folder_path):
folder_metadata = folder_path
self.load_metadata_from_folder(folder_metadata)

self._kwargs = dict(folder_path=str(folder_path.absolute()))
self._kwargs = dict(folder_path=str(folder_path.absolute()), mmap_mode=mmap_mode)
Comment thread
alejoe91 marked this conversation as resolved.

@staticmethod
def write_sorting(sorting, save_path):
Expand Down
65 changes: 63 additions & 2 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset):
assert "number" in sorting_analyzer.sorting.get_property_keys()
sorting_analyzer_reloded = load_sorting_analyzer(folder, format="auto")
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()
assert "number" in sorting_analyzer_reloded.sorting.get_property_keys()


def test_SortingAnalyzer_zarr(tmp_path, dataset):
Expand Down Expand Up @@ -201,7 +201,7 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
assert "number" in sorting_analyzer.sorting.get_property_keys()
sorting_analyzer_reloded = load_sorting_analyzer(sorting_analyzer.folder, format="auto")
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()
assert "number" in sorting_analyzer_reloded.sorting.get_property_keys()


def test_create_by_dict():
Expand Down Expand Up @@ -325,6 +325,67 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset):
assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations())


def test_load_in_lazy_mode_binary(tmp_path, dataset):
recording, sorting = dataset

folder = tmp_path / "test_SortingAnalyzer_binary_folder"
if folder.exists():
shutil.rmtree(folder)

sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None
)

sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"])
# load in lazy mode and check that spike vector and extension data are memmap
sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True)

assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), np.memmap)

template_ext = sorting_analyzer_lazy.get_extension("templates")
template_data = template_ext.data
for key, value in template_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, np.memmap)
spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes")
spike_amplitudes_data = spike_amplitudes_ext.data
for key, value in spike_amplitudes_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, np.memmap)


def test_load_in_lazy_mode_zarr(tmp_path, dataset):
import zarr
from spikeinterface.core.zarrextractors import ZarrSpikeVector

recording, sorting = dataset

folder = tmp_path / "test_SortingAnalyzer_zarr_folder.zarr"
if folder.exists():
shutil.rmtree(folder)

sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None
)

sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"])
# load in lazy mode and check that spikevector is ZarrSpikeVector andextension data are zarr arrays
sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True)

assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), ZarrSpikeVector)

template_ext = sorting_analyzer_lazy.get_extension("templates")
template_data = template_ext.data
for key, value in template_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, zarr.Array)
spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes")
spike_amplitudes_data = spike_amplitudes_ext.data
for key, value in spike_amplitudes_data.items():
if isinstance(value, np.ndarray):
assert isinstance(value, zarr.Array)


def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

register_result_extension(DummyAnalyzerExtension)
Expand Down
Loading
Loading