Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0c7d9a2
Create publish-book.yml
MMathisLab Jun 12, 2025
ad9836e
Create python-package.yml
MMathisLab Jun 12, 2025
12d5465
Merge branch 'main' into MMathisLab-actions
CeliaBenquet Jun 20, 2025
a17b954
Add makefile
CeliaBenquet Jun 20, 2025
19f1dbf
Update requirements.txt
CeliaBenquet Jun 20, 2025
0834687
Fix tests, run formatter and codespell checks
CeliaBenquet Jun 20, 2025
9669054
Add gitactions for formatting and codespell
CeliaBenquet Jun 20, 2025
cb28ccc
Update title of visual notebook
CeliaBenquet Jun 20, 2025
be26f03
Add the shell in yml git action file
CeliaBenquet Jun 20, 2025
c3d11e2
Remove yapf on notebooks
CeliaBenquet Jun 20, 2025
89b9ac7
Fix tests and add more
CeliaBenquet Jun 20, 2025
8abc266
Run formatter
CeliaBenquet Jun 20, 2025
d6a02e2
Fix imports with matplotlib
CeliaBenquet Jun 20, 2025
382ac10
Adapt so that runs for unified CEBRA and simplification
CeliaBenquet Jun 24, 2025
f049ca0
Merge main into celia/test-unified-cebra
CeliaBenquet Jul 3, 2025
6121296
Fix typos
CeliaBenquet Jul 3, 2025
c77a1f5
Fix tests
CeliaBenquet Jul 3, 2025
c03d831
Fix some unified-CEBRA related issues
CeliaBenquet Jul 3, 2025
575cb4b
Fix plots
CeliaBenquet Jul 3, 2025
2aa0e19
Fix formatting
CeliaBenquet Jul 3, 2025
af70639
fix(batch): prevent overwriting batch outputs
Jul 21, 2025
9750078
Add binary decoding and correct axis name for plot, fix plot per layers
Jul 22, 2025
5beff0c
cleaned-up requirements.txt & small fixes to pass the test & Mock KNN…
Jul 22, 2025
2482830
Merge pull request #58 from AdaptiveMotorControlLab/ananda/activation…
CeliaBenquet Jul 24, 2025
01765ce
add some comments, removed is_binary_label flag
Jul 28, 2025
fdcd2e5
Fix: skip raw input decoding for UnifiedSolver to avoid duplicating l…
Jul 31, 2025
54b3b40
Merge branch 'celia/test-unified-cebra' into ananda/binary-decoding-a…
CeliaBenquet Dec 9, 2025
9cee035
Merge pull request #59 from AdaptiveMotorControlLab/ananda/binary-dec…
CeliaBenquet Dec 11, 2025
48ad9fb
Merge branch 'main' into celia/test-unified-cebra
MMathisLab Apr 14, 2026
94b39ae
Fix: Prevent AttributeError when checking for UnifiedSolver in transf…
anandawolz Apr 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cebra_lens/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .utils_allen import *
from .utils_hpc import *
from .utils_plot import *
from .utils_wrapper import *

# selects what files can be imported when doing from CEBRA_Lens import * --> keep env clean
# __all__ = ['get_layer_activations']
157 changes: 106 additions & 51 deletions cebra_lens/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import torch
import torch.nn as nn

from .utils_plot import plot_activations
from cebra_lens import utils_wrapper



def _cut_array(array: npt.NDArray,
Expand All @@ -29,16 +30,16 @@ def _cut_array(array: npt.NDArray,
The sliced array. If both start and end indices are 0, the whole array is returned.
"""

start = cut_indices[0]
end = cut_indices[1]
start, end = cut_indices
if start == 0 and end == 0:
# If both start and end are 0, take the whole array
sliced_array = array
else:
# Otherwise, slice the array
sliced_array = array[:, start:end if end != 0 else start:]
return sliced_array
return array

end_idx = None if end == 0 else end

# Construct a slicing tuple like [:, :, ..., start:end_idx] to slice along the last axis
slicers = [slice(None)] * (array.ndim - 1) + [slice(start, end_idx)]
return array[tuple(slicers)]

def get_cut_indices(
model_: cebra.integrations.sklearn.cebra.CEBRA,
Expand Down Expand Up @@ -90,8 +91,10 @@ def get_cut_indices(
def get_activations_model(
model: cebra.integrations.sklearn.cebra.CEBRA,
data: torch.Tensor,
session_id: int = -1,
name: str = "single",
labels: Optional[torch.Tensor] = None,
session_id: int = None,
pad_before_transform: bool = True,
activations_keys_prefix: str = "model",
instance: int = 0,
layer_type: Type[nn.Module] = nn.Conv1d,
) -> Dict[str, npt.NDArray]:
Expand Down Expand Up @@ -124,53 +127,69 @@ def get_activations_model(

activations = {}
transform_kwargs = {}
if model.solver_name_ in [
"multi-session",
"multi-session-aux",
"multiobjective-solver",
]:

model_ = model.model_[session_id]
transform_kwargs.update({"session_id": session_id})

elif model.solver_name_ in [
"single-session",
"single-session-aux",
"single-session-hybrid",
"single-session-full",
]:
model_ = model.model_

if isinstance(model, cebra.integrations.sklearn.cebra.CEBRA):
model_ = model.solver_._get_model(session_id=session_id)
elif isinstance(model, cebra.solver.Solver):
model_ = model._get_model(session_id=session_id)
else:
raise NotImplementedError(
f"Solver {model.solver_name_} is not yet implemented.")
raise ValueError(
"Model must be an instance of cebra.integrations.sklearn.cebra.CEBRA "
f"or cebra.solver.Solver, got {type(model)} instead.", )

transform_kwargs.update({"session_id": session_id})

activations, handles, conv_layer_info = _attach_hooks(
activations=activations,
model=model_,
name=name,
activations_keys_prefix=activations_keys_prefix,
instance=instance,
layer_type=layer_type,
)
_ = model.transform(data, **transform_kwargs)

_ = utils_wrapper.transform(model=model,
data=data,
label=labels,
**transform_kwargs)

# remove all handles to avoid activation's problems
for handle in handles:
handle.remove()

if model.pad_before_transform:
# Padding logic: calculate the total reduction which happens based on the kernel size per layer, divide the reduction per layer into 2 parts

if hasattr(model, "pad_before_transform"):
pad_before_transform = model.pad_before_transform

if pad_before_transform:
cut_indices = get_cut_indices(model_, layer_type, conv_layer_info)
for i, (key, value) in enumerate(activations.items()):
activations[key] = _cut_array(value, cut_indices[i])
else:
cut_indices = [(0,0)] * len(handles)
# for any activation that was captured in time chunks:
# remove the padding from each chunk using cut indices,
# then, concatenate them along the time axis
for i, (key, batch_list) in enumerate(list(activations.items())):
if not isinstance(batch_list, list):
continue
sliced_chunks = [
_cut_array(chunk, cut_indices[i])
for chunk in batch_list
]
# now every chunk.shape == (1, channels, time)
axis = sliced_chunks[0].ndim - 1
activations[key] = np.concatenate(sliced_chunks, axis=axis)
# squeeze (1, channels, time) to (channels, time), so downstream tools (e.g., k‑NN regression) receive the 2D array they require
for key, arr in list(activations.items()):
if arr.ndim == 3 and arr.shape[0] == 1:
activations[key] = arr[0]

return activations


def process_activations(
models: Dict[str, List[cebra.integrations.sklearn.cebra.CEBRA]],
data: torch.Tensor,
session_id: int,
labels: Optional[torch.Tensor] = None,
session_id: int = None,
pad_before_transform: bool = True,
activations: Dict[str, npt.NDArray] = {},
layer_type: Type[nn.Module] = None,
) -> Dict[str, npt.NDArray]:
Expand Down Expand Up @@ -202,8 +221,10 @@ def process_activations(
get_activations_model(
model=model,
data=data,
labels=labels,
session_id=session_id,
name=model_name,
pad_before_transform=pad_before_transform,
activations_keys_prefix=model_name,
instance=i,
layer_type=layer_type,
))
Expand All @@ -212,18 +233,37 @@ def process_activations(


# Function to create a hook that stores the activations in the dictionary
def _get_activation(name: str, activations: Dict):

def _get_activation(activations_keys_prefix: str, activations: Dict):
"""Creates a forward hook to capture activations from a model layer.

This function returns a hook that captures the output of a model layer during the forward pass and stores it in a dictionary.

Args:
activations_keys_prefix : str
The prefix to use for the activation key, corresponding to the type of model (e.g., "single", "multi").
activations : Dict
A dictionary to store the activations. The key will be the name of the layer, and the value will be the activations.

Returns:
hook : function
A forward hook function that captures the activations.
activations : Dict
The dictionary where the activations will be stored. The key is the name of the layer, and the value is the activations.
"""
activations.setdefault(activations_keys_prefix, [])
def hook(model, input, output):
activations[name] = output.detach().squeeze().numpy()
arr = output.detach().cpu().numpy()
activations[activations_keys_prefix].append(arr)

return hook, activations


#NOTE(celia): this function is not super flexible to handle different layer types,
# but it is a good starting point.
def _attach_hooks(
activations: Dict[str, npt.NDArray],
model: cebra.integrations.sklearn.cebra.CEBRA,
name: str,
activations_keys_prefix: str,
instance: int,
layer_type: Type[nn.Module] = None,
) -> Dict[str, npt.NDArray]: # only attaches hooks on convolutional layers
Expand All @@ -237,8 +277,10 @@ def _attach_hooks(
A dictionary to store the activations. Please refer to ``activations`` returned by ``get_activations_model``.
model : cebra.integrations.sklearn.cebra.CEBRA
The model to which hooks will be attached.
name : str
A base name for the activation keys (e.g., "single", "multi").
activations_keys_prefix : str
A base name for the activation keys (e.g., "single", "multi") so that the keys are
unique for each model instance. The keys will be in the format
'{activations_keys_prefix}_{instance}_layer_{num_layer}'.
instance : int
The instance number for the model, used to differentiate between models from the same model category.
layer_type : Type[nn.Module]
Expand All @@ -258,8 +300,10 @@ def _attach_hooks(
# attach hook to the layer_type and to the output layer
if isinstance(model.net[i], layer_type) or i == len(model.net) - 1:
hook, activations = _get_activation(
f"{name}_{instance}_layer_{num_layer}", activations)
if isinstance(model.net[i], layer_type):
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
activations)
if isinstance(model.net[i],
layer_type) and layer_type == nn.Conv1d:
conv_layer_info.append(model.net[i].kernel_size[0])
handle = model.net[i].register_forward_hook(hook)
handles.append(handle)
Expand All @@ -269,7 +313,7 @@ def _attach_hooks(
for submodule in model.net[i].modules():
if isinstance(submodule, layer_type):
hook, activations = _get_activation(
f"{name}_{instance}_layer_{num_layer}",
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
activations,
)
conv_layer_info.append(submodule.kernel_size[0])
Expand All @@ -284,7 +328,7 @@ def _attach_hooks(
if bool(model.net[i]._modules):
for submodule in model.net[i].modules():
hook, activations = _get_activation(
f"{name}_{instance}_layer_{num_layer}",
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
activations,
)
handle = submodule.register_forward_hook(hook)
Expand All @@ -293,7 +337,8 @@ def _attach_hooks(

else:
hook, activations = _get_activation(
f"{name}_{instance}_layer_{num_layer}", activations)
f"{activations_keys_prefix}_{instance}_layer_{num_layer}",
activations)

handle = model.net[i].register_forward_hook(hook)
handles.append(handle)
Expand Down Expand Up @@ -345,9 +390,11 @@ def aggregate_activations(
def get_activations(
models: Dict[str, List[cebra.integrations.sklearn.cebra.CEBRA]],
data: torch.Tensor,
session_id: int,
labels: Optional[torch.Tensor] = None,
session_id: int = None,
pad_before_transform: bool = True,
activations: Optional[Dict[str, npt.NDArray]] = None,
layer_type: Optional[Type[nn.Module]] = None,
layer_type: Optional[Type[nn.Module]] = nn.Conv1d,
) -> Dict[str, npt.NDArray]:
"""Extract and organize activations from models.

Expand All @@ -370,7 +417,15 @@ def get_activations(
activations = activations or {}

aggregated_activations = aggregate_activations(
process_activations(models, data, session_id, activations, layer_type))
process_activations(
models=models,
data=data,
labels=labels,
session_id=session_id,
pad_before_transform=pad_before_transform,
activations=activations,
layer_type=layer_type,
))

activations_dict = {}
for key, value in aggregated_activations.items():
Expand Down
7 changes: 6 additions & 1 deletion cebra_lens/quantification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def iterate_over_layers(
self,
activations: List[Union[float, npt.NDArray]],
metric_func: types.FunctionType,
**kwargs,
) -> List[Union[np.float64, npt.NDArray]]:
"""Iterate over each layer of activations and apply the metric function to compute the desired metric.

Expand All @@ -46,7 +47,7 @@ def iterate_over_layers(
"""
layer_data = []
for layer_activation in activations:
layer_data.append(metric_func(layer_activation))
layer_data.append(metric_func(layer_activation, **kwargs))
return layer_data

def save(self, filepath: str, data: Dict[str, npt.NDArray]) -> None:
Expand Down Expand Up @@ -88,3 +89,7 @@ def plot(self):
The plot function is specific to a metric, e.g. intra-bin distance, RDM, CKA,...
"""
raise NotImplementedError

def output_information(self):
"""Output information about the metric class."""
print(f"Metric class: {self.__class__.__name__}")
Loading
Loading