diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index 5271d762..77ba1b21 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -44,8 +44,9 @@ def assign_labels( indices = torch.nonzero(labels == i).view(-1) # Compute average firing rates for this label. + selected_spikes = torch.index_select(spikes, dim=0, index=torch.tensor(indices)) rates[:, i] = alpha * rates[:, i] + ( - torch.sum(spikes[indices], 0) / n_labeled + torch.sum(selected_spikes, 0) / n_labeled ) # Compute proportions of spike activity per class. @@ -111,6 +112,8 @@ def all_activity( # Sum over time dimension (spike ordering doesn't matter). spikes = spikes.sum(1) + if spikes.is_sparse: + spikes = spikes.to_dense() rates = torch.zeros((n_samples, n_labels), device=spikes.device) for i in range(n_labels): @@ -152,6 +155,8 @@ def proportion_weighting( # Sum over time dimension (spike ordering doesn't matter). spikes = spikes.sum(1) + if spikes.is_sparse: + spikes = spikes.to_dense() rates = torch.zeros((n_samples, n_labels), device=spikes.device) for i in range(n_labels): diff --git a/bindsnet/learning/MCC_learning.py b/bindsnet/learning/MCC_learning.py index 14565a80..66760724 100644 --- a/bindsnet/learning/MCC_learning.py +++ b/bindsnet/learning/MCC_learning.py @@ -102,7 +102,10 @@ def update(self, **kwargs) -> None: if ((self.min is not None) or (self.max is not None)) and not isinstance( self, NoOp ): - self.feature_value.clamp_(self.min, self.max) + if self.feature_value.is_sparse: + self.feature_value = self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse() + else: + self.feature_value.clamp_(self.min, self.max) @abstractmethod def reset_state_variables(self) -> None: @@ -247,10 +250,16 @@ def _connection_update(self, **kwargs) -> None: torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt ) else: - self.feature_value -= ( - self.reduction(torch.bmm(source_s, target_x), dim=0) - * self.connection.dt - ) + if self.feature_value.is_sparse: + self.feature_value -= ( + torch.bmm(source_s, target_x) + * self.connection.dt + ).to_sparse() + else: + self.feature_value -= ( + self.reduction(torch.bmm(source_s, target_x), dim=0) + * self.connection.dt + ) del source_s, target_x # Post-synaptic update. @@ -278,10 +287,16 @@ def _connection_update(self, **kwargs) -> None: torch.mean(self.average_buffer_post, dim=0) * self.connection.dt ) else: - self.feature_value += ( - self.reduction(torch.bmm(source_x, target_s), dim=0) - * self.connection.dt - ) + if self.feature_value.is_sparse: + self.feature_value += ( + torch.bmm(source_x, target_s) + * self.connection.dt + ).to_sparse() + else: + self.feature_value += ( + self.reduction(torch.bmm(source_x, target_s), dim=0) + * self.connection.dt + ) del source_x, target_s super().update() diff --git a/bindsnet/models/models.py b/bindsnet/models/models.py index 8ae3f136..f0463410 100644 --- a/bindsnet/models/models.py +++ b/bindsnet/models/models.py @@ -4,11 +4,14 @@ import torch from scipy.spatial.distance import euclidean from torch.nn.modules.utils import _pair +from torch import device from bindsnet.learning import PostPre +from bindsnet.learning.MCC_learning import PostPre as MMCPostPre from bindsnet.network import Network from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes -from bindsnet.network.topology import Connection, LocalConnection +from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection +from bindsnet.network.topology_features import Weight class TwoLayerNetwork(Network): @@ -94,6 +97,9 @@ class DiehlAndCook2015(Network): def __init__( self, n_inpt: int, + device: device, + batch_size: int, + sparse: bool = False, n_neurons: int = 100, exc: float = 22.5, inh: float = 17.5, @@ -169,28 +175,61 @@ def __init__( ) # Connections - w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) - input_exc_conn = Connection( + if sparse: + w = 0.3 * torch.rand(batch_size, self.n_inpt, self.n_neurons) + else: + w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) + input_exc_conn = MulticompartmentConnection( source=input_layer, target=exc_layer, - w=w, - update_rule=PostPre, - nu=nu, - reduction=reduction, - wmin=wmin, - wmax=wmax, - norm=norm, + device=device, + pipeline=[ + Weight( + 'weight', + w, + range=[wmin, wmax], + norm=norm, + reduction=reduction, + nu=nu, + learning_rule=MMCPostPre, + sparse=sparse + ) + ] ) w = self.exc * torch.diag(torch.ones(self.n_neurons)) - exc_inh_conn = Connection( - source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc + if sparse: + w = w.unsqueeze(0).expand(batch_size, -1, -1) + exc_inh_conn = MulticompartmentConnection( + source=exc_layer, + target=inh_layer, + device=device, + pipeline=[ + Weight( + 'weight', + w, + range=[0, self.exc], + sparse=sparse + ) + ] ) w = -self.inh * ( torch.ones(self.n_neurons, self.n_neurons) - torch.diag(torch.ones(self.n_neurons)) ) - inh_exc_conn = Connection( - source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0 + if sparse: + w = w.unsqueeze(0).expand(batch_size, -1, -1) + inh_exc_conn = MulticompartmentConnection( + source=inh_layer, + target=exc_layer, + device=device, + pipeline=[ + Weight( + 'weight', + w, + range=[-self.inh, 0], + sparse=sparse + ) + ] ) # Add to network diff --git a/bindsnet/network/monitors.py b/bindsnet/network/monitors.py index f11a2339..dc8e9d94 100644 --- a/bindsnet/network/monitors.py +++ b/bindsnet/network/monitors.py @@ -45,6 +45,7 @@ def __init__( time: Optional[int] = None, batch_size: int = 1, device: str = "cpu", + sparse: Optional[bool] = False ): # language=rst """ @@ -62,6 +63,7 @@ def __init__( self.time = time self.batch_size = batch_size self.device = device + self.sparse = sparse # if time is not specified the monitor variable accumulate the logs if self.time is None: @@ -98,11 +100,12 @@ def record(self) -> None: for v in self.state_vars: data = getattr(self.obj, v).unsqueeze(0) # self.recording[v].append(data.detach().clone().to(self.device)) - self.recording[v].append( - torch.empty_like(data, device=self.device, requires_grad=False).copy_( - data, non_blocking=True - ) + record = torch.empty_like(data, device=self.device, requires_grad=False).copy_( + data, non_blocking=True ) + if self.sparse: + record = record.to_sparse() + self.recording[v].append(record) # remove the oldest element (first in the list) if self.time is not None: self.recording[v].pop(0) diff --git a/bindsnet/network/topology.py b/bindsnet/network/topology.py index 0f9f047c..442e9a15 100644 --- a/bindsnet/network/topology.py +++ b/bindsnet/network/topology.py @@ -450,7 +450,11 @@ def compute(self, s: torch.Tensor) -> torch.Tensor: if conn_spikes.is_sparse: conn_spikes = conn_spikes.to_dense() conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n) - out_signal = conn_spikes.sum(1) + + if conn_spikes.is_sparse: + out_signal = conn_spikes.to_dense().sum(1) + else: + out_signal = conn_spikes.sum(1) if self.traces: self.activity = out_signal diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 8338af19..26271b31 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -44,7 +44,8 @@ parser.add_argument("--test", dest="train", action="store_false") parser.add_argument("--plot", dest="plot", action="store_true") parser.add_argument("--gpu", dest="gpu", action="store_true") -parser.set_defaults(plot=True, gpu=True) +parser.add_argument("--sparse", dest="sparse", action="store_true") +parser.set_defaults(gpu=True) args = parser.parse_args() @@ -66,6 +67,7 @@ train = args.train plot = args.plot gpu = args.gpu +sparse = args.sparse update_steps = int(n_train / batch_size / n_updates) update_interval = update_steps * batch_size @@ -93,6 +95,9 @@ # Build network. network = DiehlAndCook2015( + device=device, + sparse=sparse, + batch_size=batch_size, n_inpt=784, n_neurons=n_neurons, exc=exc, @@ -142,7 +147,7 @@ spikes = {} for layer in set(network.layers): spikes[layer] = Monitor( - network.layers[layer], state_vars=["s"], time=int(time / dt), device=device + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True ) network.add_monitor(spikes[layer], name="%s_spikes" % layer) @@ -160,7 +165,8 @@ perf_ax = None voltage_axes, voltage_ims = None, None -spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device) +spike_record = [torch.zeros((batch_size, int(time / dt), n_neurons), device=device).to_sparse() for _ in range(update_interval // batch_size)] +spike_record_idx = 0 # Train the network. print("\nBegin training...") @@ -192,12 +198,13 @@ # Convert the array of labels into a tensor label_tensor = torch.tensor(labels, device=device) + spike_record_tensor = torch.cat(spike_record, dim=0) # Get network predictions. all_activity_pred = all_activity( - spikes=spike_record, assignments=assignments, n_labels=n_classes + spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes ) proportion_pred = proportion_weighting( - spikes=spike_record, + spikes=spike_record_tensor, assignments=assignments, proportions=proportions, n_labels=n_classes, @@ -235,7 +242,7 @@ # Assign labels to excitatory layer neurons. assignments, proportions, rates = assign_labels( - spikes=spike_record, + spikes=spike_record_tensor, labels=label_tensor, n_labels=n_classes, rates=rates, @@ -256,11 +263,10 @@ # Add to spikes recording. s = spikes["Ae"].get("s").permute((1, 0, 2)) - spike_record[ - (step * batch_size) - % update_interval : (step * batch_size % update_interval) - + s.size(0) - ] = s + spike_record[spike_record_idx] = s + spike_record_idx += 1 + if spike_record_idx == len(spike_record): + spike_record_idx = 0 # Get voltage recording. exc_voltages = exc_voltage_monitor.get("v")