Hello! I am building a spiking neural network with surrogate gradient learning, following the tutorial at fzenke.net with some small adaptations, but I am running into problems during model inference.
The model is not a PyTorch nn.Module, which is why I cannot use the usual .eval() and state_dict() functions for inference and saving. Instead, I save the learned weights in a .pt file using torch.save().
My problem occurs when I load the weights and run predictions with the same code (just without training): I get a different model performance on every restart of the Python kernel (I use Google Colab). In theory, all operations should be deterministic given the same input and weights, which is why this behaviour confuses me. Does anyone have an idea what might cause it?
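For reference, this is roughly how I pin down the random state at the top of the notebook (a sketch; the `seed_everything` helper name is mine, and `torch.use_deterministic_algorithms` is the only call beyond the seeds already shown in my code below):

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 123):
    """Seed every RNG I am aware of, so repeated runs should match."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)          # no-op on CPU-only runtimes
    torch.use_deterministic_algorithms(True)  # raise if a nondeterministic op is hit


seed_everything(123)
```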
Here are the relevant code snippets (imports excluded):
Initialising the weights of a one-layer network (only one weight matrix):
```python
random.seed(123)
torch.manual_seed(123)
np.random.seed(123)

weight_scale = 0.2
w1 = torch.empty((nb_inputs, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w1, mean=0.0, std=weight_scale / np.sqrt(nb_inputs))
```
Saving and loading the parameters:
```python
torch.save(w1, 'path/onelayersnn.pt')
w1 = torch.load('path/onelayersnn.pt')
```
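Since w1 was created with requires_grad=True, I also tried detaching it before saving and passing map_location when loading, so the file does not depend on the device the previous kernel happened to use (a sketch with a stand-in tensor and a local filename instead of my real path):

```python
import torch

# stand-in for the trained weights
w1 = torch.randn(5, 3, requires_grad=True)

# save only the weight values, detached from the autograd graph and moved to CPU
torch.save(w1.detach().cpu(), 'onelayersnn.pt')

# reload on whatever device the current kernel provides
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
w1 = torch.load('onelayersnn.pt', map_location=device)
w1.requires_grad_()  # only needed if training continues afterwards
```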
Code that runs the network on some input, and the predict function (alpha and beta are float constants):
```python
def run_snn(inputs, batch_size, syn, mem):
    """
    Runs the SNN for one batch within one epoch.
    :param inputs: spiking input to the network
    :param batch_size: batch size to be run
    :param syn: last synaptic current value
    :param mem: last membrane potential value
    :returns: membrane potentials, output spikes, last synaptic current, last membrane potential
    """
    # carry over the state variables from the former call --> no reset of state variables
    syn_here = syn
    mem_here = mem

    # lists to record membrane potential and output spikes over the simulation time
    mem_rec = []
    spk_rec = []

    # Compute layer activity
    out = torch.zeros((batch_size, nb_outputs), device=device, dtype=dtype)  # initialization
    # multiply the input spikes with the weight matrix; this feeds the synaptic
    # variable syn and the membrane potential mem
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))

    # loop over time
    for t in range(nb_steps):
        mthr = mem_here - 1.0  # subtract the threshold to see if the neurons spike
        out = spike_fn(mthr)   # get the layer spiking activity
        rst = out.detach()     # we do not want to backprop through the reset

        new_syn = alpha * syn_here + h1[:, t]                  # new synaptic current for the next timestep (PSP?)
        new_mem = (beta * mem_here + syn_here) * (1.0 - rst)   # new membrane potential, reset where spikes occurred

        mem_rec.append(mem_here)  # record the membrane potential
        spk_rec.append(out)       # record the spikes

        mem_here = new_mem  # set the new membrane potential for this timestep
        syn_here = new_syn  # set the new synaptic current for this timestep

    # last synaptic current and membrane potential
    last_syn = syn_here
    last_mem = mem_here

    # merge the recorded membrane potentials and output spikes into single tensors
    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    return mem_rec, spk_rec, last_syn, last_mem


def predict(x_data, y_data):
    """
    Predicts the class of the input data based on the maximum membrane potential.
    :param x_data: X
    :returns: y_pred
    """
    syn = torch.zeros((len(y_data), nb_outputs), device=device, dtype=dtype)
    mem = torch.zeros((len(y_data), nb_outputs), device=device, dtype=dtype)
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(
            x_data, y_data, len(y_data), nb_steps, nb_inputs, max_time, shuffle=False):
        output, _, _, _ = run_snn(x_local.to_dense(), len(y_data), syn, mem)
        m, _ = torch.max(output, 1)  # max over time
        _, am = torch.max(m, 1)      # argmax over output units
        preds = am.float()
    return preds.cpu().numpy()
```
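To narrow down whether the weights or the forward pass differ between restarts, I print a fingerprint of the loaded tensor: if the weight checksum matches across restarts but the predictions do not, the nondeterminism must come from the forward pass or the data generator. This is a sketch with a stand-in tensor; the `checksum` helper is my own, and in the notebook I would call it on w1 and on the predictions array:

```python
import hashlib

import torch


def checksum(t: torch.Tensor) -> str:
    """Deterministic fingerprint of a tensor's values."""
    arr = t.detach().cpu().numpy()
    return hashlib.sha256(arr.tobytes()).hexdigest()[:16]


# stand-in for the loaded w1; in the notebook: print(checksum(w1))
w1_demo = torch.ones(3, 2)
print(checksum(w1_demo))  # should be identical across restarts if the file is unchanged
```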
Thank you very much!