Hi,
I need to compute the norm of the gradients w.r.t. the network parameters for each input sample individually. Right now I am doing this as follows:
pred = net(input)
grad_w_norm_batch = []  # per-sample gradient norms
for i in range(len(input)):
    loss = criterion(pred[i:i + 1], targets[i:i + 1])  # only consider the loss on the i-th sample
    net.zero_grad()
    loss.backward(retain_graph=True)
    # gather gradients w.r.t. the weights on the current sample
    grad_l = []
    for param in net.parameters():
        # skip parameters that received no gradient
        if param.grad is not None:
            grad_l.append(torch.clone(param.grad.view(-1)))
    grad_l = torch.cat(grad_l, 0)
    norm = grad_l.norm(2)
    grad_w_norm_batch.append(norm.detach())
However, this is noticeably slow, since it runs one backward pass per sample. Is there another way or trick to do this faster?
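One direction that may be worth trying is `torch.func`, which can vectorize the per-sample gradient computation instead of looping over the batch. The following is only a minimal sketch, assuming PyTorch 2.0 or newer. The small `net`, the `MSELoss` criterion, and the random `inputs` and `targets` are hypothetical stand-ins for the real setup, and `sample_loss` is a helper name introduced for illustration. Whether it is actually faster will depend on the model and batch size.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical stand-ins for the post's net, criterion, input and targets.
net = nn.Sequential(nn.Linear(10, 16), nn.Tanh(), nn.Linear(16, 3))
criterion = nn.MSELoss()
inputs = torch.randn(8, 10)
targets = torch.randn(8, 3)

# Detached parameter tensors keyed by name, used for the functional call.
params = {name: p.detach() for name, p in net.named_parameters()}

def sample_loss(params, x, y):
    # Re-add the batch dimension so the model and loss see a batch of size 1.
    pred = functional_call(net, params, (x.unsqueeze(0),))
    return criterion(pred, y.unsqueeze(0))

# grad differentiates w.r.t. `params` (the first argument);
# vmap maps over dim 0 of the inputs and targets only.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, inputs, targets)

# Each entry has shape (batch, *param.shape): flatten and concatenate per sample.
flat = torch.cat([g.reshape(g.shape[0], -1) for g in per_sample_grads.values()], dim=1)
grad_w_norm_batch = flat.norm(2, dim=1)  # one L2 norm per sample
```

This computes the same quantity as the loop above (the L2 norm of the concatenated parameter gradients for each sample) in a single vectorized call. Layers that mix samples within a batch, such as BatchNorm in training mode, would need extra care with this approach.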