I have a parallelized training setup with instances of the same model running on different machines. After a set of batches goes through each model, I want to manually set each parameter's weights to the average of that parameter's weights across all instances.
Following other posts about manually updating weights, my training loop looks like this:
    import numpy as np
    import torch

    for i_batch, sample_batched in enumerate(train_dataloader):
        optimizer.zero_grad()
        output = model(sample_batched['data'])
        loss = loss_function(output, sample_batched['label'])
        loss.backward()
        optimizer.step()

        # List of instances of the same model loaded from saved state_dicts.
        nets = [...]
        net_params = [x.named_parameters() for x in nets]
        with torch.no_grad():
            for (name, param), *q in zip(model.named_parameters(), *net_params):
                # Collect this parameter's values from every other instance
                # and average them elementwise.
                param_vals = np.array([p[1].data.numpy() for p in q])
                average_param_val = np.mean(param_vals, axis=0)
                # Broadcast the averaged numpy values into a tensor of the
                # parameter's shape, then overwrite the parameter in place.
                values = torch.ones(param.shape) * average_param_val
                param.copy_(values)
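For reference, here is the same averaging step written purely with torch ops, as a sketch of what I am aiming for (assuming all models share the same architecture, live on the CPU, and `nets` is the same list as above):

    # Sketch of the averaging step using only torch ops,
    # avoiding the round trip through numpy.
    with torch.no_grad():
        other_params = [n.named_parameters() for n in nets]
        for (name, param), *others in zip(model.named_parameters(), *other_params):
            # Stack the corresponding parameter from every instance in nets:
            # shape (len(nets), *param.shape), then average over instances.
            stacked = torch.stack([p for _, p in others])
            param.copy_(stacked.mean(dim=0))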
After this point my training loop continues with more input data, but the training and validation accuracy no longer change, which suggests to me that the optimizer stops updating the weights after the first manual averaging step.
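As a quick sanity check (a hypothetical snippet, not something in my current code), I could snapshot one parameter around `optimizer.step()` and see whether it actually changes:

    # Hypothetical check: does optimizer.step() still change the weights?
    probe = next(model.parameters())
    before = probe.detach().clone()
    optimizer.step()
    print('parameter changed after step:', not torch.equal(before, probe.detach()))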
Does anything seem off with how I am updating the parameter values that would cause the optimizer to stop updating them?