Weights stop updating after manual update?

I have a parallelized training setup with instances of the same model on different machines. After a set of batches goes through the model, I want to manually set each parameter's weights to the average of that parameter's weights across all instances.

Following other posts about manually updating weights, my training loop looks like this:

for i_batch, sample_batched in enumerate(train_dataloader):
    optimizer.zero_grad()                                   # reset gradients from the previous batch
    output = model(sample_batched['data'])                  # forward pass
    loss = loss_function(output, sample_batched['label'])
    loss.backward()                                         # compute gradients
    optimizer.step()                                        # optimizer updates the weights


import numpy as np
import torch

nets = [...]  # List of instances of the same model loaded from saved state_dicts.
net_params = [x.named_parameters() for x in nets]

with torch.no_grad():
    for (name, param), *q in zip(model.named_parameters(), *net_params):
        # Average this parameter's values across all of the loaded instances.
        param_vals = np.array([p[1].data.numpy() for p in q])
        average_param_val = np.mean(param_vals, axis=0)
        # Broadcast the averaged array back into a tensor and overwrite the parameter in place.
        values = torch.ones(param.shape) * average_param_val
        param.copy_(values)
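
For what it's worth, I believe the numpy round trip is equivalent to stacking the matching tensors and averaging them directly in torch, roughly like this sketch (assuming all the nets live on the CPU):

with torch.no_grad():
    for (name, param), *q in zip(model.named_parameters(), *net_params):
        # Stack the matching parameter from every instance and average elementwise.
        stacked = torch.stack([p.detach() for _, p in q])
        param.copy_(stacked.mean(dim=0))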

After this point my training loop continues with more input data, but the training and validation accuracy no longer change, which suggests that the optimizer stops updating the weights after the first manual update.

Does anything seem off with how I am updating the values of the parameter weights that would cause them to stop being updated by the optimizer?

I think your averaging loop only runs once: named_parameters() returns a generator, so iterating it exhausts it and you would need to refresh it. Or are you re-creating net_params before each update?
Add a print statement to the custom update loop and make sure the parameters are indeed being updated.
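
For example, a quick check like this would tell you whether the copy actually changes anything on each round (fc1 is just a placeholder for whatever layer your model has):

before = model.fc1.weight.detach().clone()  # fc1 is a placeholder layer name
# ... run the custom averaging block here ...
after = model.fc1.weight.detach()
print(torch.equal(before, after))  # True here would mean the averaging changed nothing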

Thanks for taking a look. This loop happens within a larger loop that instantiates a dataloader for each file containing my data.

The problem was actually that all instances of my network need to start with the same initial weight values for parameter averaging to work as part of a distributed training setup. So in my case, setting the random seed at instantiation solved this.
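
For anyone who hits the same thing, the fix was essentially this (a sketch; MyModel and the seed value are placeholders for my own model class and whichever fixed seed you pick, as long as every machine uses the same one):

import torch

torch.manual_seed(0)  # same fixed seed on every machine (placeholder value)
model = MyModel()     # placeholder model class; every instance now starts from identical initial weights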