I solved it. After receiving the list of gradients, I assign each one to the matching parameter's `.grad`:

    import torch

    # grads: one summed gradient per parameter, accumulated over the batches from train_batches
    for param, grad in zip(model.parameters(), grads):
        param.grad = torch.as_tensor(grad, dtype=param.dtype, device=param.device)
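Below is a minimal end-to-end sketch of this pattern. Everything in it is an illustrative assumption rather than the original setup: `nn.Linear`, `MSELoss`, plain `SGD`, and the hypothetical `train_batches` list of random `(inputs, targets)` pairs. The sketch sums each parameter's gradient over all batches, writes the sums into `param.grad`, and then calls `optimizer.step()`. The step only works if every assigned gradient has the same shape, dtype, and device as its parameter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model, loss, and data standing in for the real setup
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
train_batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(3)]

# Accumulate one summed gradient per parameter across all batches
grads = [torch.zeros_like(p) for p in model.parameters()]
for x, y in train_batches:
    model.zero_grad()
    loss_fn(model(x), y).backward()
    for i, p in enumerate(model.parameters()):
        grads[i] += p.grad.detach()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = [p.detach().clone() for p in model.parameters()]  # kept only for checking

# Assign the summed gradients, then take a single optimizer step
for param, grad in zip(model.parameters(), grads):
    param.grad = grad.clone()
optimizer.step()
```

Summing per-batch gradients gives the gradient of the summed loss. An equivalent approach is to skip `zero_grad()` between batches and let `backward()` accumulate into `.grad` directly. The explicit list used here is handy when the gradients are computed somewhere else, for example in another process.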