I am using Dask to perform distributed training with PyTorch. I split the data I want to train on into chunks, and on each worker I compute the loss for one chunk and call loss.backward(). This seems to work (.forward() and .backward() are doing their job). What I need to do now is to accumulate the gradients and sum them up, as described here. Once they are summed, I need to set them on the model so that I can call optimizer.step().
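The summing step relies on the fact that the gradient of a sum of losses equals the sum of the per-chunk gradients. A minimal sketch with a hypothetical toy linear model and random data (not the model in this question) checks this numerically. It assumes the total loss is the sum of the chunk losses; if each chunk loss were a mean, the chunk gradients would instead need to be weighted by chunk size.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy model and data, used only to illustrate the idea
model = torch.nn.Linear(3, 1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)

def chunk_loss(xc, yc):
    # Sum of squared errors, so chunk losses add up to the full-batch loss
    return ((model(xc) - yc) ** 2).sum()

# Gradient of the full-batch loss
model.zero_grad()
chunk_loss(x, y).backward()
full_grads = [p.grad.detach().clone() for p in model.parameters()]

# Gradients computed chunk by chunk, copied out after each backward()
per_chunk = []
for xc, yc in zip(x.chunk(2), y.chunk(2)):
    model.zero_grad()
    chunk_loss(xc, yc).backward()
    per_chunk.append([p.grad.detach().clone() for p in model.parameters()])

# Sum the per-chunk gradients parameter by parameter
summed = [torch.stack(gs).sum(dim=0) for gs in zip(*per_chunk)]

for g_full, g_sum in zip(full_grads, summed):
    print(torch.allclose(g_full, g_sum, atol=1e-5))  # prints True
```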
The delayed function that computes the loss and calls loss.backward() is shown below:
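from collections import OrderedDict

import dask


@dask.delayed
def train_batches(index, chunk, targets, model, optimizer, lossfxn,
                  atoms_per_image, device):
    """A function that allows training per batches"""
    inputs = OrderedDict(chunk)
    outputs = model(inputs)

    if lossfxn is None:
        # MSELoss is a user-defined function here, not torch.nn.MSELoss
        loss = MSELoss(outputs, targets[index], optimizer,
                       atoms_per_image[index], device=device)
        loss.backward()
    else:
        # raise needs an exception object; raising a bare string is a TypeError
        raise NotImplementedError('I do not know what to do')

    # Collect the gradient of every parameter after backward()
    gradients = []
    for param in model.parameters():
        gradients.append(param.grad)

    return outputs, loss, gradients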
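To show how the per-chunk results might come back from Dask, here is a self-contained sketch of the same pattern using a hypothetical toy model and a simplified stand-in for train_batches (the name chunk_gradients and the squared-error loss are assumptions, not the code above). It zeroes the gradients before each backward(), returns copies rather than references, and uses Dask's synchronous scheduler so the demo is deterministic; with a threaded scheduler, tasks sharing one model object would also share its .grad tensors.

```python
import dask
import torch

torch.manual_seed(0)

# Hypothetical toy model and data standing in for the real ones
model = torch.nn.Linear(3, 1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)
chunks = list(zip(x.chunk(4), y.chunk(4)))


@dask.delayed
def chunk_gradients(model, xc, yc):
    """Return the loss of one chunk and detached copies of its gradients."""
    model.zero_grad()  # start from clean .grad buffers for this chunk
    loss = ((model(xc) - yc) ** 2).sum()
    loss.backward()
    # Copy each gradient so later backward() calls cannot overwrite it
    grads = [p.grad.detach().clone() for p in model.parameters()]
    return loss.item(), grads


tasks = [chunk_gradients(model, xc, yc) for xc, yc in chunks]
# "synchronous" runs the tasks one after another in this process
results = dask.compute(*tasks, scheduler="synchronous")

losses = [loss for loss, _ in results]
grads_per_chunk = [grads for _, grads in results]
```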
Am I getting the gradients correctly by looping over model.parameters()? Should I use param.grad.data instead? I do get lists back when I call this function, and I assume they contain the gradients.
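On the .data question: param.grad is already an ordinary tensor, and .detach() is the usual modern replacement for .data. What matters more is that param.grad is typically the tensor that later backward() calls accumulate into in place, so a list of references can change after it is built. A small sketch with a hypothetical single parameter w shows the difference between keeping a reference and keeping a copy made with .detach().clone():

```python
import torch

# Hypothetical single parameter, standing in for one of model.parameters()
w = torch.tensor([1.0, 2.0], requires_grad=True)

(w * 3).sum().backward()            # gradient of sum(3 * w) is 3 per element
ref = w.grad                        # a reference to the live gradient tensor
snapshot = w.grad.detach().clone()  # an independent copy taken now

(w * 3).sum().backward()            # second backward() adds 3 again, in place

print(ref)       # tensor([6., 6.]): the stored reference changed too
print(snapshot)  # tensor([3., 3.]): the copy kept the first result
```

In train_batches, returning [param.grad.detach().clone() for param in model.parameters()] would be one way to apply the same idea.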
Now I would like to sum them and set this sum as the gradient of the model, so that I can then call optimizer.step(). How can I achieve that?
I would appreciate any suggestions.
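One possible way to do the summation step, sketched under the assumption that each worker returns a list of gradient tensors in the same order as model.parameters() (as train_batches does): sum the lists element-wise, assign each sum to the matching param.grad, then call optimizer.step(). The toy model, the SGD optimizer, the helper name apply_summed_gradients and the stand-in gradients are all hypothetical.

```python
import torch

torch.manual_seed(0)

# Hypothetical model and optimizer standing in for the real ones
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Stand-in for the gradient lists returned by the workers: one list per chunk,
# one tensor per parameter, in model.parameters() order
worker_grads = [
    [torch.ones_like(p) for p in model.parameters()],
    [2 * torch.ones_like(p) for p in model.parameters()],
]


def apply_summed_gradients(model, optimizer, grads_per_chunk):
    """Sum per-chunk gradients and take one optimizer step with the result."""
    for i, param in enumerate(model.parameters()):
        # Add up the i-th gradient from every chunk
        total = torch.stack([grads[i] for grads in grads_per_chunk]).sum(dim=0)
        # Assign the sum as this parameter's gradient (matching device/dtype)
        param.grad = total.to(device=param.device, dtype=param.dtype)
    optimizer.step()


before = [p.detach().clone() for p in model.parameters()]
apply_summed_gradients(model, optimizer, worker_grads)
after = [p.detach().clone() for p in model.parameters()]
```

With plain SGD, each parameter moves by lr times the summed gradient (here 0.1 × 3). If each chunk loss is a mean over equally sized chunks, dividing total by the number of chunks keeps the step comparable to full-batch training.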