Different results with the same code

Hello! I am building a spiking neural network (SNN) trained with surrogate gradient learning, following the tutorial on fzenke.net with a few small adaptations, but I am running into problems during model inference.

The model is not a PyTorch `nn.Module`, so I cannot use the usual `.eval()` and `state_dict()` methods for inference and saving. Instead, I save the learned weights to a .pt file with `torch.save()`.

The problem appears when I load the weights and make predictions with the same code (just without the training step): every time I restart the Python kernel (I use Google Colab), I get a different model performance. In theory, every operation should be deterministic given the same input and weights, so I am confused by this behaviour. Does anyone have an idea what might cause it?
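To test the premise that the inputs and weights really are identical after each restart, one option is to print a fingerprint of both tensors in every session and compare them. This is a minimal sketch: `tensor_fingerprint` is a hypothetical helper, and it assumes float tensors that can be copied to NumPy on the CPU.

```python
import hashlib

import torch


def tensor_fingerprint(t: torch.Tensor) -> str:
    """Hypothetical helper: SHA-256 digest of a tensor's dtype, shape and raw values."""
    arr = t.detach().cpu().contiguous().numpy()
    h = hashlib.sha256()
    h.update(str(arr.dtype).encode())  # different dtypes give different digests
    h.update(str(arr.shape).encode())  # different shapes give different digests
    h.update(arr.tobytes())            # the actual values, bit for bit
    return h.hexdigest()


# Possible usage after loading the weights and building the dense input:
# print("w1   ", tensor_fingerprint(w1))
# print("input", tensor_fingerprint(x_local.to_dense()))
```

If both fingerprints match across restarts but the predictions still differ, the difference must come from something else the forward pass reads (for example the constants it uses); if they differ, the weights or the input pipeline change between sessions.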

Here are the relevant code snippets (imports excluded):

Initialising the weights of a one-layer network (a single weight matrix):

```python
random.seed(123)
torch.manual_seed(123)
np.random.seed(123)

weight_scale = 0.2

w1 = torch.empty((nb_inputs, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w1, mean=0.0, std=weight_scale/np.sqrt(nb_inputs))
```
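With the seeds fixed, the initial weights themselves should be reproducible. A quick check, as a sketch: `init_w1` is a hypothetical helper that wraps the lines above on the CPU with `float32`, and the sizes used below are made up for illustration.

```python
import math

import torch


def init_w1(nb_inputs: int, nb_outputs: int, weight_scale: float = 0.2, seed: int = 123) -> torch.Tensor:
    """Hypothetical helper: draw the weight matrix as above after reseeding torch."""
    torch.manual_seed(seed)  # torch.nn.init.normal_ draws from torch's global RNG
    w = torch.empty((nb_inputs, nb_outputs), dtype=torch.float32, requires_grad=True)
    torch.nn.init.normal_(w, mean=0.0, std=weight_scale / math.sqrt(nb_inputs))
    return w


a = init_w1(700, 20)
b = init_w1(700, 20)
print(torch.equal(a, b))  # True
```

This only concerns a fresh initialisation; once the trained weights are loaded from disk and the loaded tensor replaces `w1`, the initial draw no longer affects the predictions.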

Saving and loading the parameters:

```python
torch.save(w1, 'path/onelayersnn.pt')
w1 = torch.load('path/onelayersnn.pt')
```
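A sketch of a checkpoint round trip that also stores the constants the forward pass reads, so that a fresh session restores them together with the weights. The temporary path, the random stand-in for `w1`, and the `alpha`/`beta` values are placeholders, not values from the original code.

```python
import os
import tempfile

import torch

w1 = torch.randn(4, 3, requires_grad=True)  # placeholder for the trained weights

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "onelayersnn.pt")
    # Store the weights together with the (placeholder) constants used by run_snn.
    torch.save({"w1": w1.detach().cpu(), "alpha": 0.9, "beta": 0.8}, path)
    ckpt = torch.load(path, map_location="cpu")  # read back before the directory is removed

w1_loaded = ckpt["w1"]
print(torch.equal(w1.detach(), w1_loaded))  # True: the round trip is bit-exact
```

The loaded tensor has `requires_grad=False` here because it was detached before saving, which is fine for inference; move it to the target device with `.to(device)` before use.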

The code that runs the network on an input batch, and the predict function (`alpha` and `beta` are float constants):

```python
def run_snn(inputs, batch_size, syn, mem):
    """
    Runs the SNN for one batch within one epoch
    :param inputs: spiking input to the network
    :param batch_size: batch size to be run
    :param syn: last synaptic current value
    :param mem: last membrane potential value
    :returns: membrane potentials, output spikes, last synaptic current and last membrane potential
    """
    # initialise synaptic currents and membrane potentials with the previous values --> no reset of state variables
    syn_here = syn
    mem_here = mem

    # lists to record the membrane potentials and output spikes over the simulation time
    mem_rec = []
    spk_rec = []

    # Compute the activity of the output layer (the only layer in this network)
    out = torch.zeros((batch_size, nb_outputs), device=device, dtype=dtype)  # initialisation
    # multiply the input spikes by the weight matrix; the result feeds the synaptic variable syn and the membrane potential mem
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    # loop over time
    for t in range(nb_steps):
        mthr = mem_here - 1.0  # subtract the threshold to see whether the neurons spike
        out = spike_fn(mthr)  # get the layer's spiking activity
        rst = out.detach()  # we do not want to backprop through the reset

        new_syn = alpha*syn_here + h1[:, t]  # new synaptic input current for the next time step
        new_mem = (beta*mem_here + syn_here)*(1.0 - rst)  # new membrane potential for the next time step

        mem_rec.append(mem_here)  # record the membrane potential
        spk_rec.append(out)  # record the spikes

        mem_here = new_mem  # set the new membrane potential
        syn_here = new_syn  # set the new synaptic current

    # last synaptic current and membrane potential
    last_syn = syn_here
    last_mem = mem_here
    # merge the recorded membrane potentials into a single tensor
    mem_rec = torch.stack(mem_rec, dim=1)
    # merge the output spikes into a single tensor
    spk_rec = torch.stack(spk_rec, dim=1)

    return mem_rec, spk_rec, last_syn, last_mem
```
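For inference alone, the same recurrence can be run without autograd. A minimal sketch, assuming that the forward pass of `spike_fn` is a Heaviside step on `mem - 1.0` (as with the tutorial's surrogate spike function); `heaviside` and `run_lif` are hypothetical names, and `w`, `alpha` and `beta` are passed in as arguments instead of being read from globals.

```python
import torch


def heaviside(x: torch.Tensor) -> torch.Tensor:
    # Assumed forward pass of the spike function: 1 where x > 0, else 0.
    return (x > 0).to(x.dtype)


@torch.no_grad()  # inference only: no autograd graph is built
def run_lif(inputs, w, alpha, beta, syn, mem):
    """inputs: (batch, time, nb_inputs); w: (nb_inputs, nb_outputs)."""
    h = torch.einsum("abc,cd->abd", inputs, w)
    mem_rec, spk_rec = [], []
    for t in range(inputs.shape[1]):
        out = heaviside(mem - 1.0)                   # spike where the membrane exceeds 1.0
        new_syn = alpha * syn + h[:, t]
        new_mem = (beta * mem + syn) * (1.0 - out)   # reset the neurons that spiked
        mem_rec.append(mem)
        spk_rec.append(out)
        mem, syn = new_mem, new_syn
    return torch.stack(mem_rec, dim=1), torch.stack(spk_rec, dim=1), syn, mem


# Toy check: one neuron, constant input, weight 2.0, alpha = beta = 0
x = torch.ones(1, 5, 1)
_, spikes, _, _ = run_lif(x, torch.full((1, 1), 2.0), 0.0, 0.0, torch.zeros(1, 1), torch.zeros(1, 1))
print(spikes.flatten().tolist())  # [0.0, 0.0, 1.0, 0.0, 1.0]
```

On the CPU, calling `run_lif` twice on the same input and weights should return identical tensors, which is a quick way to confirm that the recurrence on its own is deterministic.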


```python
def predict(x_data, y_data):
    """
    Predicts the class of the input data based on the maximum membrane potential
    :param x_data: input data X
    :param y_data: labels y (also used to set the batch size)
    :returns: y_pred
    """
    syn = torch.zeros((len(y_data), nb_outputs), device=device, dtype=dtype)
    mem = torch.zeros((len(y_data), nb_outputs), device=device, dtype=dtype)
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, len(y_data), nb_steps, nb_inputs, max_time, shuffle=False):
        output, _, _, _ = run_snn(x_local.to_dense(), len(y_data), syn, mem)
        m, _ = torch.max(output, 1)  # max over time
        _, am = torch.max(m, 1)      # argmax over output units
        preds = am.float()
    return preds.cpu().numpy()
```
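The readout in `predict` takes each output unit's peak membrane potential over time and then picks the unit with the highest peak. A toy illustration with made-up values:

```python
import torch

# Recorded membrane potentials shaped (batch, time, nb_outputs): 2 samples, 3 steps, 2 units.
mem_rec = torch.tensor([
    [[0.1, 0.5], [0.9, 0.2], [0.3, 0.4]],  # sample 0: unit 0 peaks at 0.9, unit 1 at 0.5
    [[0.0, 0.7], [0.2, 0.6], [0.1, 0.8]],  # sample 1: unit 0 peaks at 0.2, unit 1 at 0.8
])

m, _ = torch.max(mem_rec, 1)  # max over time -> shape (batch, nb_outputs)
_, am = torch.max(m, 1)       # argmax over output units -> shape (batch,)
print(am.tolist())            # [0, 1]
```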

Thank you very much!