Hi @ptrblck, thank you for your response.
I have trimmed my code down as much as I could; here it is:
import os
import torch
import argparse
import numpy as np
from time import time as t
from torchvision import transforms
from bindsnet.datasets import MNIST
from bindsnet.network import Network, load
from bindsnet.network.topology import Connection
from bindsnet.network.nodes import RealInput, IFNodes
# Filesystem locations.
# ROOT_DIR is the base directory; the MNIST data directory is
# <ROOT_DIR>/data/MNIST.
ROOT_DIR = '/home/Users/Ati/Desktop/GD'
data_path = os.path.join(os.path.join(ROOT_DIR, 'data'), 'MNIST')
def main(seed=0, n_train=60000, n_test=10000, time=50, lr=0.01, lr_decay=0.95,
         update_interval=500, max_prob=1.0, plot=False, train=True, gpu=False,
         network=None):
    """Train or evaluate a two-layer BindsNET network on MNIST.

    In training mode a 784 -> 10 network (``RealInput`` -> ``IFNodes``) is
    built unless one is supplied, and its ``('X', 'Y')`` weights are updated
    with a hand-written stochastic gradient descent step after every example.

    Args:
        seed: Seed for the NumPy and torch random number generators.
        n_train: Number of examples to process when ``train`` is true.
        n_test: Number of examples to process when ``train`` is false.
        time: Number of simulation steps per example.
        lr: Initial learning rate of the manual SGD step.
        lr_decay: Factor applied to ``lr`` every ``update_interval`` examples.
        update_interval: Number of examples between accuracy records; also
            the length of the rolling ``losses`` / ``correct`` buffers.
        max_prob: Accepted for interface compatibility; unused here.
        plot: Accepted for interface compatibility; unused here.
        train: If true, build (when needed) and train the network; otherwise
            only run the network that was passed in.
        gpu: Accepted for interface compatibility; unused here.
        network: Optional existing network with layers ``'X'`` and ``'Y'``
            and an ``('X', 'Y')`` connection. Required when ``train`` is
            false, because this function has no other source of a trained
            network.

    Returns:
        The network that was trained or evaluated.

    Raises:
        ValueError: If ``train`` is false and no ``network`` was given.
    """
    np.random.seed(seed)
    # Seed torch too: every tensor below is a torch tensor.
    # NOTE(review): bindsnet presumably initialises weights from torch's RNG
    # - confirm against the installed version.
    torch.manual_seed(seed)

    criterion = torch.nn.CrossEntropyLoss()
    n_examples = n_train if train else n_test

    if network is None:
        if not train:
            # Previously this path reached `network.run` with `network`
            # unbound and failed with a NameError.
            raise ValueError(
                'evaluation (train=False) requires an existing network; '
                'pass it via the `network` argument'
            )

        # Network building.
        network = Network()

        # Groups of neurons.
        input_layer = RealInput(n=784, sum_input=True)
        output_layer = IFNodes(n=10, sum_input=True)
        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')

        # Connections between groups of neurons.
        input_connection = Connection(
            source=input_layer, target=output_layer, norm=150, wmin=-1, wmax=1
        )
        network.add_connection(input_connection, source='X', target='Y')

    # Load MNIST from the configured data directory. The original passed the
    # literal string 'root' here instead of the path.
    root = data_path
    dataset = MNIST(root=root, download=True)
    images, labels = dataset.data, dataset.targets
    # Flatten each image to 784 values and scale by 1 / 255.
    images = images.view(-1, 784) / 255
    n_samples = len(labels)

    grads = {}
    accuracies = []
    predictions = []
    # Rolling buffers covering the most recent `update_interval` examples.
    losses = torch.zeros(update_interval)
    correct = torch.zeros(update_interval)

    # Run training / evaluation.
    for i in range(n_examples):
        # Wrap around so n_examples may exceed the dataset size.
        idx = i % n_samples
        label = torch.Tensor([labels[idx]]).long()
        image = images[idx]

        # Run simulation for single datum: the same image at every time step.
        inpts = {'X': image.repeat(time, 1)}
        network.run(inpts=inpts, time=time)

        # Retrieve summed inputs from both layers.
        summed_inputs = {l: network.layers[l].summed for l in network.layers}

        # Compute softmax of the summed output inputs and get predicted label.
        probabilities = summed_inputs['Y'].softmax(0)
        output = probabilities.view(1, -1)
        predicted = output.argmax(1).item()
        correct[i % update_interval] = int(predicted == label[0].item())
        predictions.append(predicted)

        # Record the loss between output and true label.
        # NOTE(review): CrossEntropyLoss applies log-softmax itself and
        # `output` is already a softmax, so this logged value is not a plain
        # cross-entropy. Kept unchanged because it is only recorded.
        losses[i % update_interval] = criterion(output, label)

        if train:
            # Gradient of the loss w.r.t. the output activity:
            # softmax(output) - one_hot(label).
            grads['dl/df'] = summed_inputs['Y'].softmax(0)
            grads['dl/df'][label] -= 1

            # Compute gradient of the summed voltages WRT connection weights.
            # This is an approximation; the summed voltages are not a
            # smooth function of the connection weights.
            grads['dl/dw'] = torch.ger(summed_inputs['X'], grads['dl/df'])
            grads['dl/db'] = grads['dl/df']

            # Do stochastic gradient descent calculation.
            network.connections['X', 'Y'].w -= lr * grads['dl/dw']
            # The network built above has no 'Y_b' bias layer, so the
            # unconditional bias update raised a KeyError. Only update the
            # bias when such a connection actually exists.
            if ('Y_b', 'Y') in network.connections:
                network.connections['Y_b', 'Y'].w -= lr * grads['dl/db']

        if i > 0 and i % update_interval == 0:
            accuracies.append(correct.mean() * 100)
            if train:
                # Decay learning rate.
                lr *= lr_decay

        network.reset_()  # Reset state variables.

    accuracies.append(correct.mean() * 100)

    if train:
        lr *= lr_decay
        # NOTE(review): assumes every connection's update_rule exposes a
        # numeric `weight_decay` attribute - confirm for the installed
        # bindsnet version.
        for c in network.connections:
            network.connections[c].update_rule.weight_decay *= lr_decay

    return network
if __name__ == '__main__':
    # Command-line parameters; each one maps to a keyword argument of main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--n_train', type=int, default=60000, help='no. of training samples')
    parser.add_argument('--n_test', type=int, default=10000, help='no. of test samples')
    parser.add_argument('--time', default=25, type=int, help='simulation time')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--lr_decay', default=0.95, type=float, help='learning rate decay')
    parser.add_argument('--update_interval', default=500, type=int, help='no. examples between evaluation')
    parser.add_argument('--max_prob', default=1.0, type=float, help='maximum prob. of input spikes')
    # `--train` alone could never switch training off (store_true with a
    # default of True); `--test` writes False to the same destination.
    parser.add_argument('--train', dest='train', action='store_true', help='train phase')
    parser.add_argument('--test', dest='train', action='store_false', help='test phase')
    parser.set_defaults(plot=False, train=True)
    args = parser.parse_args()

    main(
        seed=args.seed,
        n_train=args.n_train,
        n_test=args.n_test,
        time=args.time,
        lr=args.lr,
        lr_decay=args.lr_decay,
        update_interval=args.update_interval,
        max_prob=args.max_prob,
        train=args.train,
    )
I'm not sure whether you need to install the bindsnet package; it is only used here to create a network with two layers and the connections between those layers.
If more explanation is needed for clarification, please tell me and I will provide it :)