In the context of a central server and several nodes that all share the same model architecture but have different weights, I compute and back-propagate a loss defined as a norm of the difference between the central model's weights and the weights of each updated node model. I want to compute the gradient norm of this difference separately for every node. The catch is that I load the weights of the nodes one by one, run the backward pass on the central model each time, and only perform the update step at the end:
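optimizer.zero_grad()
for i, node in enumerate(nodes):
    # load node i's weights into the shared user_model container
    user_modelname = f'user_models/user_{i}.pth'
    user_model.load_model(name=user_modelname)
    # regularization term between the central model and node i
    reg = models_diff(simple_model, user_model)
    # gradients are summed into simple_model's .grad across iterations
    reg.backward()
optimizer.step()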
My question is the following: when I check p.grad for the central model's parameters, I get the sum of all the gradients, but I would like the gradient of each node individually. I can't perform one update per node, since the model difference would change at every iteration. Is there a way to get the intermediate gradient values of the backward pass?
But what if I still want to apply the backpropagation? My temporary solution is to store the difference in the accumulated gradients between two consecutive nodes, but this seems inefficient and hacky.
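For reference, the gradient-differencing workaround described above can be sketched as follows. This is a minimal sketch under assumptions: models_diff is implemented here as a hypothetical squared L2 distance between parameters, and the nodes are freshly initialized nn.Linear stand-ins rather than models loaded from user_models/user_{i}.pth. Each node's contribution is recovered as the increase in the accumulated .grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def models_diff(central, node):
    # Hypothetical regularizer: squared L2 distance between the two models' parameters.
    # The node's parameters are detached so only the central model receives gradients.
    return sum(((p - q.detach()) ** 2).sum()
               for p, q in zip(central.parameters(), node.parameters()))


simple_model = nn.Linear(3, 1)
nodes = [nn.Linear(3, 1) for _ in range(4)]  # hypothetical stand-ins for the loaded user models
optimizer = torch.optim.SGD(simple_model.parameters(), lr=0.1)
params = list(simple_model.parameters())

optimizer.zero_grad()
per_node_grads = []
for node in nodes:
    # snapshot the accumulated gradients before this node's backward pass
    before = [torch.zeros_like(p) if p.grad is None else p.grad.clone() for p in params]
    models_diff(simple_model, node).backward()
    # this node's contribution is the increase in the accumulated .grad
    per_node_grads.append([p.grad - b for p, b in zip(params, before)])
optimizer.step()
```

The cost is one full gradient copy per parameter per node, and the subtraction can introduce small floating-point differences compared with computing each gradient directly.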
I'm not sure I understand the question correctly. torch.autograd.grad will also calculate the gradients, but instead of accumulating them in the .grad attributes of the leaf tensors, it returns them.
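Applied to the original loop, this could look like the sketch below. It assumes a hypothetical models_diff (squared L2 distance between parameters) and randomly initialized nn.Linear stand-ins for the node models. torch.autograd.grad returns each node's gradient separately, from which a per-node gradient norm follows; the summed gradient is then written into .grad so that optimizer.step() performs the same update as the accumulated backward() calls.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def models_diff(central, node):
    # Hypothetical regularizer: squared L2 distance between the two models' parameters.
    return sum(((p - q.detach()) ** 2).sum()
               for p, q in zip(central.parameters(), node.parameters()))


simple_model = nn.Linear(3, 1)
nodes = [nn.Linear(3, 1) for _ in range(4)]  # hypothetical stand-ins for the loaded user models
optimizer = torch.optim.SGD(simple_model.parameters(), lr=0.1)
params = list(simple_model.parameters())

optimizer.zero_grad()
per_node_grads = []
for node in nodes:
    reg = models_diff(simple_model, node)
    # one gradient per tensor in params; .grad is not touched
    per_node_grads.append(torch.autograd.grad(reg, params))

# per-node gradient norm over all parameters of the central model
grad_norms = [torch.sqrt(sum((g ** 2).sum() for g in grads)) for grads in per_node_grads]
print([round(n.item(), 4) for n in grad_norms])

# write the summed gradient into .grad, matching what the accumulated backward() calls produce
for j, p in enumerate(params):
    p.grad = sum(grads[j] for grads in per_node_grads)
optimizer.step()
```

Since each node builds its own graph, no retain_graph is needed here; every graph is used exactly once.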
OK, I got it (I hadn't quite understood how torch.autograd.grad worked).
Basically, for every parameter that has requires_grad=True, I can call:
torch.autograd.grad(reg, layer)
and it returns the gradient of the regularization with respect to that parameter (layer has to be a tensor here, such as a layer's weight, not the nn.Module itself). Thanks a lot.
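Instead of one call per parameter, torch.autograd.grad also accepts a list of tensors and returns a tuple of gradients in the same order, so the graph is traversed once. The sketch below uses a small nn.Sequential model and a stand-in scalar loss (both assumptions, not the thread's models). Each call frees the graph by default, so when calling it repeatedly on the same loss, every call except the last needs retain_graph=True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
x = torch.randn(8, 3)
# stand-in scalar loss; in the thread this would be reg = models_diff(...)
reg = model(x).pow(2).mean()

# only tensors that require grad can be passed as inputs
named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
params = [p for _, p in named_params]

# one call for every parameter: the returned tuple follows the order of params;
# retain_graph=True only because the graph is used again below
grads = torch.autograd.grad(reg, params, retain_graph=True)
for (name, _), g in zip(named_params, grads):
    print(name, tuple(g.shape))

# a single parameter also works; as the last use of the graph, this call may free it
first_weight_grad = torch.autograd.grad(reg, model[0].weight)[0]
```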
Just a follow-up question though: do we agree that this does not update the .grad attribute of the layer, and that I can still call .backward() as usual afterwards?
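import torch
import torch.nn as nn

lin = nn.Linear(1, 1)
x = torch.randn(1, 1)
out = lin(x)
# autograd.grad returns the gradients instead of accumulating them into .grad
grads = torch.autograd.grad(out, lin.parameters())
print(grads)
> (tensor([[1.3361]]), tensor([1.]))
# the .grad attributes of the parameters are left untouched
print([p.grad for p in lin.parameters()])
> [None, None]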
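One caveat for the follow-up question: by default torch.autograd.grad frees the graph after computing the gradients, so a later .backward() through the same graph will typically raise a RuntimeError. Passing retain_graph=True keeps the graph alive. A minimal sketch based on the nn.Linear example:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(1, 1)
x = torch.randn(1, 1)
out = lin(x)

# retain_graph=True keeps the graph alive for the backward() call below;
# with the default (False) the graph would be freed after this call
grads = torch.autograd.grad(out, lin.parameters(), retain_graph=True)
print([p.grad for p in lin.parameters()])  # still [None, None]

# backward() then accumulates into .grad as usual
out.backward()
print([p.grad for p in lin.parameters()])
```

Because nothing had been accumulated before, .grad ends up holding the same values that autograd.grad returned.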