-
Notifications
You must be signed in to change notification settings - Fork 341
merge with master #722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
merge with master #722
Changes from all commits
0308b27
d1d3e42
471d455
8b1cb4e
12d8e03
2c0760d
9678254
fd0ef9b
ec63566
e0f9bf2
e9dfcbf
bed0623
326d270
95626a9
9ff6f01
3c9528a
717ee17
988cb91
88f01a8
85c8cb1
6e938f4
5833f9d
80ba5e5
a7dc702
1c8cb82
be29c59
731206e
6b6e4b5
08f4bd0
5aceb2e
2adfcba
b5c608b
4c894b7
f5a716d
f19692c
d072219
fd3b904
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -102,7 +102,12 @@ def update(self, **kwargs) -> None: | |||||||||
| if ((self.min is not None) or (self.max is not None)) and not isinstance( | ||||||||||
| self, NoOp | ||||||||||
| ): | ||||||||||
| self.feature_value.clamp_(self.min, self.max) | ||||||||||
| if self.feature_value.is_sparse: | ||||||||||
| self.feature_value = ( | ||||||||||
| self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse() | ||||||||||
| ) | ||||||||||
| else: | ||||||||||
| self.feature_value.clamp_(self.min, self.max) | ||||||||||
|
|
||||||||||
| @abstractmethod | ||||||||||
| def reset_state_variables(self) -> None: | ||||||||||
|
|
@@ -247,10 +252,15 @@ def _connection_update(self, **kwargs) -> None: | |||||||||
| torch.mean(self.average_buffer_pre, dim=0) * self.connection.dt | ||||||||||
| ) | ||||||||||
| else: | ||||||||||
| self.feature_value -= ( | ||||||||||
| self.reduction(torch.bmm(source_s, target_x), dim=0) | ||||||||||
| * self.connection.dt | ||||||||||
| ) | ||||||||||
| if self.feature_value.is_sparse: | ||||||||||
| self.feature_value -= ( | ||||||||||
| torch.bmm(source_s, target_x) * self.connection.dt | ||||||||||
| ).to_sparse() | ||||||||||
|
Comment on lines
+255
to
+258
|
||||||||||
| else: | ||||||||||
| self.feature_value -= ( | ||||||||||
| self.reduction(torch.bmm(source_s, target_x), dim=0) | ||||||||||
| * self.connection.dt | ||||||||||
| ) | ||||||||||
| del source_s, target_x | ||||||||||
|
|
||||||||||
| # Post-synaptic update. | ||||||||||
|
|
@@ -278,10 +288,15 @@ def _connection_update(self, **kwargs) -> None: | |||||||||
| torch.mean(self.average_buffer_post, dim=0) * self.connection.dt | ||||||||||
| ) | ||||||||||
| else: | ||||||||||
| self.feature_value += ( | ||||||||||
| self.reduction(torch.bmm(source_x, target_s), dim=0) | ||||||||||
| * self.connection.dt | ||||||||||
| ) | ||||||||||
| if self.feature_value.is_sparse: | ||||||||||
| self.feature_value += ( | ||||||||||
| torch.bmm(source_x, target_s) * self.connection.dt | ||||||||||
| ).to_sparse() | ||||||||||
|
Comment on lines
+293
to
+294
|
||||||||||
| torch.bmm(source_x, target_s) * self.connection.dt | |
| ).to_sparse() | |
| self.reduction(torch.bmm(source_x, target_s), dim=0).to_sparse() * self.connection.dt | |
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -98,7 +98,10 @@ def update(self) -> None: | |
| (self.connection.wmin != -np.inf).any() | ||
| or (self.connection.wmax != np.inf).any() | ||
| ) and not isinstance(self, NoOp): | ||
| self.connection.w.clamp_(self.connection.wmin, self.connection.wmax) | ||
| if self.connection.w.is_sparse: | ||
| raise Exception("SparseConnection isn't supported for wmin\\wmax") | ||
|
||
| else: | ||
| self.connection.w.clamp_(self.connection.wmin, self.connection.wmax) | ||
|
|
||
|
|
||
| class NoOp(LearningRule): | ||
|
|
@@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None: | |
| if self.nu[0].any(): | ||
| source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float() | ||
| target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0] | ||
| self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0) | ||
| update = self.reduction(torch.bmm(source_s, target_x), dim=0) | ||
| if self.connection.w.is_sparse: | ||
| update = update.to_sparse() | ||
| self.connection.w -= update | ||
| del source_s, target_x | ||
|
|
||
| # Post-synaptic update. | ||
|
|
@@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None: | |
| self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1] | ||
| ) | ||
| source_x = self.source.x.view(batch_size, -1).unsqueeze(2) | ||
| self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0) | ||
| update = self.reduction(torch.bmm(source_x, target_s), dim=0) | ||
| if self.connection.w.is_sparse: | ||
| update = update.to_sparse() | ||
| self.connection.w += update | ||
| del source_x, target_s | ||
|
|
||
| super().update() | ||
|
|
@@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None: | |
|
|
||
| # Pre-synaptic update. | ||
| update = self.reduction(torch.bmm(source_s, target_x), dim=0) | ||
| if self.connection.w.is_sparse: | ||
| update = update.to_sparse() | ||
| self.connection.w += self.nu[0] * update | ||
|
|
||
| # Post-synaptic update. | ||
| update = self.reduction(torch.bmm(source_x, target_s), dim=0) | ||
| if self.connection.w.is_sparse: | ||
| update = update.to_sparse() | ||
| self.connection.w += self.nu[1] * update | ||
|
|
||
| super().update() | ||
|
|
@@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None: | |
| a_minus = torch.tensor(a_minus, device=self.connection.w.device) | ||
|
|
||
| # Compute weight update based on the eligibility value of the past timestep. | ||
| update = reward * self.eligibility | ||
| self.connection.w += self.nu[0] * self.reduction(update, dim=0) | ||
| update = self.reduction(reward * self.eligibility, dim=0) | ||
| if self.connection.w.is_sparse: | ||
| update = update.to_sparse() | ||
| self.connection.w += self.nu[0] * update | ||
|
|
||
| # Update P^+ and P^- values. | ||
| self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) | ||
|
|
@@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None: | |
| self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace) | ||
| self.eligibility_trace += self.eligibility / self.tc_e_trace | ||
|
|
||
| update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace | ||
| if self.connection.w.is_sparse: | ||
| update = update.to_sparse() | ||
| # Compute weight update. | ||
| self.connection.w += ( | ||
| self.nu[0] * self.connection.dt * reward * self.eligibility_trace | ||
| ) | ||
| self.connection.w += update | ||
|
|
||
| # Update P^+ and P^- values. | ||
| self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) | ||
|
|
@@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None: | |
| ) * source_x[:, None] | ||
|
|
||
| # Compute weight update. | ||
| self.connection.w += self.nu[0] * reward * self.eligibility_trace | ||
| update = self.nu[0] * reward * self.eligibility_trace | ||
| if self.connection.w.is_sparse: | ||
| update = update.to_sparse() | ||
| self.connection.w += update | ||
|
|
||
| super().update() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,11 +4,18 @@ | |
| import torch | ||
| from scipy.spatial.distance import euclidean | ||
| from torch.nn.modules.utils import _pair | ||
| from torch import device | ||
|
||
|
|
||
| from bindsnet.learning import PostPre | ||
| from bindsnet.learning.MCC_learning import PostPre as MMCPostPre | ||
| from bindsnet.network import Network | ||
| from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes | ||
| from bindsnet.network.topology import Connection, LocalConnection | ||
| from bindsnet.network.topology import ( | ||
| Connection, | ||
| LocalConnection, | ||
| MulticompartmentConnection, | ||
| ) | ||
| from bindsnet.network.topology_features import Weight | ||
|
|
||
|
|
||
| class TwoLayerNetwork(Network): | ||
|
|
@@ -94,6 +101,9 @@ class DiehlAndCook2015(Network): | |
| def __init__( | ||
| self, | ||
| n_inpt: int, | ||
| device: str = "cpu", | ||
| batch_size: int = None, | ||
| sparse: bool = False, | ||
| n_neurons: int = 100, | ||
| exc: float = 22.5, | ||
| inh: float = 17.5, | ||
|
|
@@ -102,6 +112,7 @@ def __init__( | |
| reduction: Optional[callable] = None, | ||
| wmin: float = 0.0, | ||
| wmax: float = 1.0, | ||
| w_dtype: torch.dtype = torch.float32, | ||
| norm: float = 78.4, | ||
| theta_plus: float = 0.05, | ||
| tc_theta_decay: float = 1e7, | ||
|
|
@@ -124,6 +135,7 @@ def __init__( | |
| dimension. | ||
| :param wmin: Minimum allowed weight on input to excitatory synapses. | ||
| :param wmax: Maximum allowed weight on input to excitatory synapses. | ||
| :param w_dtype: Data type for :code:`w` tensor | ||
| :param norm: Input to excitatory layer connection weights normalization | ||
| constant. | ||
| :param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane | ||
|
|
@@ -170,27 +182,57 @@ def __init__( | |
|
|
||
| # Connections | ||
| w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) | ||
| input_exc_conn = Connection( | ||
| input_exc_conn = MulticompartmentConnection( | ||
| source=input_layer, | ||
| target=exc_layer, | ||
| w=w, | ||
| update_rule=PostPre, | ||
| nu=nu, | ||
| reduction=reduction, | ||
| wmin=wmin, | ||
| wmax=wmax, | ||
| norm=norm, | ||
| device=device, | ||
| pipeline=[ | ||
| Weight( | ||
| "weight", | ||
| w, | ||
| value_dtype=w_dtype, | ||
| range=[wmin, wmax], | ||
| norm=norm, | ||
| reduction=reduction, | ||
| nu=nu, | ||
| learning_rule=MMCPostPre, | ||
| sparse=sparse, | ||
| batch_size=batch_size, | ||
| ) | ||
| ], | ||
| ) | ||
| w = self.exc * torch.diag(torch.ones(self.n_neurons)) | ||
| exc_inh_conn = Connection( | ||
| source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc | ||
| if sparse: | ||
| w = w.unsqueeze(0).expand(batch_size, -1, -1) | ||
| exc_inh_conn = MulticompartmentConnection( | ||
| source=exc_layer, | ||
| target=inh_layer, | ||
| device=device, | ||
| pipeline=[ | ||
| Weight( | ||
| "weight", w, value_dtype=w_dtype, range=[0, self.exc], sparse=sparse | ||
| ) | ||
| ], | ||
| ) | ||
| w = -self.inh * ( | ||
| torch.ones(self.n_neurons, self.n_neurons) | ||
| - torch.diag(torch.ones(self.n_neurons)) | ||
| ) | ||
| inh_exc_conn = Connection( | ||
| source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0 | ||
| if sparse: | ||
| w = w.unsqueeze(0).expand(batch_size, -1, -1) | ||
| inh_exc_conn = MulticompartmentConnection( | ||
| source=inh_layer, | ||
| target=exc_layer, | ||
| device=device, | ||
| pipeline=[ | ||
| Weight( | ||
| "weight", | ||
| w, | ||
| value_dtype=w_dtype, | ||
| range=[-self.inh, 0], | ||
| sparse=sparse, | ||
| ) | ||
| ], | ||
| ) | ||
|
|
||
| # Add to network | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `torch.tensor(indices)` call is unnecessary and may cause performance issues or device mismatches. The `indices` variable is already a tensor on the same device as `spikes`. It should be used directly: `torch.index_select(spikes, dim=0, index=indices)`.