Hello! I am building a spiking neural network with surrogate gradient learning, following the tutorial at fzenke.net with some small adaptations, but I am running into problems during model inference.
The model is not a PyTorch nn.Module, which is why I cannot use the usual .eval() and state_dict() functions for inference and saving. Instead, I save the learned weights in a .pt file using torch.save().
My problem occurs when I load the weights and run predictions with the same code (just without training): I get a different model performance on every restart of the Python kernel (I use Google Colab). In theory, all operations should be deterministic given the same input and weights, which is why this behaviour confuses me. Does anyone have an idea what might cause it?
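For reference, this is roughly how I pin down the random state at the top of the notebook (a sketch; the `seed_everything` helper name is mine, and `torch.use_deterministic_algorithms` is the only call beyond the seeds already shown in my code below):

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 123):
    """Seed every RNG I am aware of, so repeated runs should match."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)          # no-op on CPU-only runtimes
    torch.use_deterministic_algorithms(True)  # raise if a nondeterministic op is hit


seed_everything(123)
```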
Here are the relevant code snippets (imports excluded):
Initialising the weights of a one-layer network (only one weight matrix):
```python
random.seed(123)
torch.manual_seed(123)
np.random.seed(123)

weight_scale = 0.2
w1 = torch.empty((nb_inputs, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w1, mean=0.0, std=weight_scale / np.sqrt(nb_inputs))
```
Saving and loading the parameters:
```python
torch.save(w1, 'path/onelayersnn.pt')
w1 = torch.load('path/onelayersnn.pt')
```
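Since w1 was created with requires_grad=True, I also tried detaching it before saving and passing map_location when loading, so the file does not depend on the device the previous kernel happened to use (a sketch with a stand-in tensor and a local filename instead of my real path):

```python
import torch

# stand-in for the trained weights
w1 = torch.randn(5, 3, requires_grad=True)

# save only the weight values, detached from the autograd graph and moved to CPU
torch.save(w1.detach().cpu(), 'onelayersnn.pt')

# reload on whatever device the current kernel provides
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
w1 = torch.load('onelayersnn.pt', map_location=device)
w1.requires_grad_()  # only needed if training continues afterwards
```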
Code that runs the network on some input, and the predict function (alpha and beta are float constants):
```python
def run_snn(inputs, batch_size, syn, mem):
    """
    Runs the SNN for one batch within one epoch.
    :param inputs: spiking input to the network
    :param batch_size: batch size to be run
    :param syn: last synaptic current value
    :param mem: last membrane potential value
    :returns: membrane potentials, output spikes, last synaptic current, last membrane potential
    """
    # carry over the state variables from the former call --> no reset of state variables
    syn_here = syn
    mem_here = mem

    # lists to record membrane potential and output spikes over the simulation time
    mem_rec = []
    spk_rec = []

    # Compute layer activity
    out = torch.zeros((batch_size, nb_outputs), device=device, dtype=dtype)  # initialization
    # multiply the input spikes with the weight matrix; this feeds the synaptic
    # variable syn and the membrane potential mem
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))

    # loop over time
    for t in range(nb_steps):
        mthr = mem_here - 1.0  # subtract the threshold to see if the neurons spike
        out = spike_fn(mthr)   # get the layer spiking activity
        rst = out.detach()     # we do not want to backprop through the reset

        new_syn = alpha * syn_here + h1[:, t]                  # new synaptic current for the next timestep (PSP?)
        new_mem = (beta * mem_here + syn_here) * (1.0 - rst)   # new membrane potential, reset where spikes occurred

        mem_rec.append(mem_here)  # record the membrane potential
        spk_rec.append(out)       # record the spikes

        mem_here = new_mem  # set the new membrane potential for this timestep
        syn_here = new_syn  # set the new synaptic current for this timestep

    # last synaptic current and membrane potential
    last_syn = syn_here
    last_mem = mem_here

    # merge the recorded membrane potentials and output spikes into single tensors
    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    return mem_rec, spk_rec, last_syn, last_mem


def predict(x_data, y_data):
    """
    Predicts the class of the input data based on the maximum membrane potential.
    :param x_data: X
    :returns: y_pred
    """
    syn = torch.zeros((len(y_data), nb_outputs), device=device, dtype=dtype)
    mem = torch.zeros((len(y_data), nb_outputs), device=device, dtype=dtype)
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(
            x_data, y_data, len(y_data), nb_steps, nb_inputs, max_time, shuffle=False):
        output, _, _, _ = run_snn(x_local.to_dense(), len(y_data), syn, mem)
        m, _ = torch.max(output, 1)  # max over time
        _, am = torch.max(m, 1)      # argmax over output units
        preds = am.float()
    return preds.cpu().numpy()
```
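To narrow down whether the weights or the forward pass differ between restarts, I print a fingerprint of the loaded tensor: if the weight checksum matches across restarts but the predictions do not, the nondeterminism must come from the forward pass or the data generator. This is a sketch with a stand-in tensor; the `checksum` helper is my own, and in the notebook I would call it on w1 and on the predictions array:

```python
import hashlib

import torch


def checksum(t: torch.Tensor) -> str:
    """Deterministic fingerprint of a tensor's values."""
    arr = t.detach().cpu().numpy()
    return hashlib.sha256(arr.tobytes()).hexdigest()[:16]


# stand-in for the loaded w1; in the notebook: print(checksum(w1))
w1_demo = torch.ones(3, 2)
print(checksum(w1_demo))  # should be identical across restarts if the file is unchanged
```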
Thank you very much!