Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bindsnet/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def assign_labels(
indices = torch.nonzero(labels == i).view(-1)

# Compute average firing rates for this label.
selected_spikes = torch.index_select(spikes, dim=0, index=torch.tensor(indices))
rates[:, i] = alpha * rates[:, i] + (
torch.sum(spikes[indices], 0) / n_labeled
torch.sum(selected_spikes, 0) / n_labeled
)

# Compute proportions of spike activity per class.
Expand Down Expand Up @@ -111,6 +112,8 @@ def all_activity(

# Sum over time dimension (spike ordering doesn't matter).
spikes = spikes.sum(1)
if spikes.is_sparse:
spikes = spikes.to_dense()

rates = torch.zeros((n_samples, n_labels), device=spikes.device)
for i in range(n_labels):
Expand Down Expand Up @@ -152,6 +155,8 @@ def proportion_weighting(

# Sum over time dimension (spike ordering doesn't matter).
spikes = spikes.sum(1)
if spikes.is_sparse:
spikes = spikes.to_dense()

rates = torch.zeros((n_samples, n_labels), device=spikes.device)
for i in range(n_labels):
Expand Down
33 changes: 24 additions & 9 deletions bindsnet/learning/MCC_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def update(self, **kwargs) -> None:
if ((self.min is not None) or (self.max is not None)) and not isinstance(
self, NoOp
):
self.feature_value.clamp_(self.min, self.max)
if self.feature_value.is_sparse:
self.feature_value = self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse()
else:
self.feature_value.clamp_(self.min, self.max)

@abstractmethod
def reset_state_variables(self) -> None:
Expand Down Expand Up @@ -247,10 +250,16 @@ def _connection_update(self, **kwargs) -> None:
torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt
)
else:
self.feature_value -= (
self.reduction(torch.bmm(source_s, target_x), dim=0)
* self.connection.dt
)
if self.feature_value.is_sparse:
self.feature_value -= (
torch.bmm(source_s, target_x)
* self.connection.dt
).to_sparse()
else:
self.feature_value -= (
self.reduction(torch.bmm(source_s, target_x), dim=0)
* self.connection.dt
)
del source_s, target_x

# Post-synaptic update.
Expand Down Expand Up @@ -278,10 +287,16 @@ def _connection_update(self, **kwargs) -> None:
torch.mean(self.average_buffer_post, dim=0) * self.connection.dt
)
else:
self.feature_value += (
self.reduction(torch.bmm(source_x, target_s), dim=0)
* self.connection.dt
)
if self.feature_value.is_sparse:
self.feature_value += (
torch.bmm(source_x, target_s)
* self.connection.dt
).to_sparse()
else:
self.feature_value += (
self.reduction(torch.bmm(source_x, target_s), dim=0)
* self.connection.dt
)
del source_x, target_s

super().update()
Expand Down
67 changes: 53 additions & 14 deletions bindsnet/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
import torch
from scipy.spatial.distance import euclidean
from torch.nn.modules.utils import _pair
from torch import device

from bindsnet.learning import PostPre
from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
from bindsnet.network import Network
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
from bindsnet.network.topology import Connection, LocalConnection
from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection
from bindsnet.network.topology_features import Weight


class TwoLayerNetwork(Network):
Expand Down Expand Up @@ -94,6 +97,9 @@ class DiehlAndCook2015(Network):
def __init__(
self,
n_inpt: int,
device: device,
batch_size: int,
sparse: bool = False,
n_neurons: int = 100,
exc: float = 22.5,
inh: float = 17.5,
Expand Down Expand Up @@ -169,28 +175,61 @@ def __init__(
)

# Connections
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
input_exc_conn = Connection(
if sparse:
w = 0.3 * torch.rand(batch_size, self.n_inpt, self.n_neurons)
else:
w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
input_exc_conn = MulticompartmentConnection(
source=input_layer,
target=exc_layer,
w=w,
update_rule=PostPre,
nu=nu,
reduction=reduction,
wmin=wmin,
wmax=wmax,
norm=norm,
device=device,
pipeline=[
Weight(
'weight',
w,
range=[wmin, wmax],
norm=norm,
reduction=reduction,
nu=nu,
learning_rule=MMCPostPre,
sparse=sparse
)
]
)
w = self.exc * torch.diag(torch.ones(self.n_neurons))
exc_inh_conn = Connection(
source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
if sparse:
w = w.unsqueeze(0).expand(batch_size, -1, -1)
exc_inh_conn = MulticompartmentConnection(
source=exc_layer,
target=inh_layer,
device=device,
pipeline=[
Weight(
'weight',
w,
range=[0, self.exc],
sparse=sparse
)
]
)
w = -self.inh * (
torch.ones(self.n_neurons, self.n_neurons)
- torch.diag(torch.ones(self.n_neurons))
)
inh_exc_conn = Connection(
source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
if sparse:
w = w.unsqueeze(0).expand(batch_size, -1, -1)
inh_exc_conn = MulticompartmentConnection(
source=inh_layer,
target=exc_layer,
device=device,
pipeline=[
Weight(
'weight',
w,
range=[-self.inh, 0],
sparse=sparse
)
]
)

# Add to network
Expand Down
11 changes: 7 additions & 4 deletions bindsnet/network/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
time: Optional[int] = None,
batch_size: int = 1,
device: str = "cpu",
sparse: Optional[bool] = False
):
# language=rst
"""
Expand All @@ -62,6 +63,7 @@ def __init__(
self.time = time
self.batch_size = batch_size
self.device = device
self.sparse = sparse

# if time is not specified the monitor variable accumulate the logs
if self.time is None:
Expand Down Expand Up @@ -98,11 +100,12 @@ def record(self) -> None:
for v in self.state_vars:
data = getattr(self.obj, v).unsqueeze(0)
# self.recording[v].append(data.detach().clone().to(self.device))
self.recording[v].append(
torch.empty_like(data, device=self.device, requires_grad=False).copy_(
data, non_blocking=True
)
record = torch.empty_like(data, device=self.device, requires_grad=False).copy_(
data, non_blocking=True
)
if self.sparse:
record = record.to_sparse()
self.recording[v].append(record)
# remove the oldest element (first in the list)
if self.time is not None:
self.recording[v].pop(0)
Expand Down
6 changes: 5 additions & 1 deletion bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,11 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
if conn_spikes.is_sparse:
conn_spikes = conn_spikes.to_dense()
conn_spikes = conn_spikes.view(s.size(0), self.source.n, self.target.n)
out_signal = conn_spikes.sum(1)

if conn_spikes.is_sparse:
out_signal = conn_spikes.to_dense().sum(1)
else:
out_signal = conn_spikes.sum(1)

if self.traces:
self.activity = out_signal
Expand Down
28 changes: 17 additions & 11 deletions examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
parser.add_argument("--test", dest="train", action="store_false")
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=True, gpu=True)
parser.add_argument("--sparse", dest="sparse", action="store_true")
parser.set_defaults(gpu=True)

args = parser.parse_args()

Expand All @@ -66,6 +67,7 @@
train = args.train
plot = args.plot
gpu = args.gpu
sparse = args.sparse

update_steps = int(n_train / batch_size / n_updates)
update_interval = update_steps * batch_size
Expand Down Expand Up @@ -93,6 +95,9 @@

# Build network.
network = DiehlAndCook2015(
device=device,
sparse=sparse,
batch_size=batch_size,
n_inpt=784,
n_neurons=n_neurons,
exc=exc,
Expand Down Expand Up @@ -142,7 +147,7 @@
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(
network.layers[layer], state_vars=["s"], time=int(time / dt), device=device
network.layers[layer], state_vars=["s"], time=int(time / dt), device=device, sparse=True
)
network.add_monitor(spikes[layer], name="%s_spikes" % layer)

Expand All @@ -160,7 +165,8 @@
perf_ax = None
voltage_axes, voltage_ims = None, None

spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)
spike_record = [torch.zeros((batch_size, int(time / dt), n_neurons), device=device).to_sparse() for _ in range(update_interval // batch_size)]
spike_record_idx = 0

# Train the network.
print("\nBegin training...")
Expand Down Expand Up @@ -192,12 +198,13 @@
# Convert the array of labels into a tensor
label_tensor = torch.tensor(labels, device=device)

spike_record_tensor = torch.cat(spike_record, dim=0)
# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record, assignments=assignments, n_labels=n_classes
spikes=spike_record_tensor, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
spikes=spike_record_tensor,
assignments=assignments,
proportions=proportions,
n_labels=n_classes,
Expand Down Expand Up @@ -235,7 +242,7 @@

# Assign labels to excitatory layer neurons.
assignments, proportions, rates = assign_labels(
spikes=spike_record,
spikes=spike_record_tensor,
labels=label_tensor,
n_labels=n_classes,
rates=rates,
Expand All @@ -256,11 +263,10 @@

# Add to spikes recording.
s = spikes["Ae"].get("s").permute((1, 0, 2))
spike_record[
(step * batch_size)
% update_interval : (step * batch_size % update_interval)
+ s.size(0)
] = s
spike_record[spike_record_idx] = s
spike_record_idx += 1
if spike_record_idx == len(spike_record):
spike_record_idx = 0

# Get voltage recording.
exc_voltages = exc_voltage_monitor.get("v")
Expand Down
Loading