Computing the gradient w.r.t. the network parameters for each input sample

Hi,

I need to compute the norm of the gradient w.r.t. the network parameters for each input sample. Right now I am doing it as follows:

grad_w_norm_batch = []
pred = net(input)
for i in range(len(input)):
    loss = criterion(pred[i:i + 1], targets[i:i + 1])  # loss on the i-th sample only
    net.zero_grad()
    loss.backward(retain_graph=True)  # keep the graph so we can backprop again for the next sample

    # gather the gradients w.r.t. the weights for the current sample
    grad_l = []
    for param in net.parameters():
        # skip parameters that received no gradient
        if param.grad is not None:
            grad_l.append(torch.clone(param.grad.view(-1)))
    grad_l = torch.cat(grad_l, 0)
    norm = grad_l.norm(2)
    grad_w_norm_batch.append(norm.detach())

However, it is notably slow. Is there any other way or trick to do this faster?
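
One direction I have been looking at is the per-sample-gradients approach with torch.func (grad + vmap), which should compute all per-sample gradients in one vectorized call. Below is a minimal sketch of what I think this would look like, assuming PyTorch >= 2.0 and that net has no layers that mix samples across the batch (e.g. BatchNorm in training mode); I have not benchmarked it:

import torch
from torch.func import functional_call, grad, vmap

# detached copies of the parameters and buffers, keyed by name
params = {k: v.detach() for k, v in net.named_parameters()}
buffers = {k: v.detach() for k, v in net.named_buffers()}

def sample_loss(params, buffers, sample, target):
    # run net functionally with the given parameters on a single sample
    pred = functional_call(net, (params, buffers), (sample.unsqueeze(0),))
    return criterion(pred, target.unsqueeze(0))

# grad() differentiates w.r.t. params; vmap() maps over the batch dimension
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, None, 0, 0))(
    params, buffers, input, targets
)

# flatten each parameter's per-sample gradients and take the 2-norm per sample
flat = torch.cat([g.reshape(g.shape[0], -1) for g in per_sample_grads.values()], dim=1)
grad_w_norm_batch = flat.norm(2, dim=1)

Would something like this be the recommended way, or is there a cheaper trick if only the norms are needed?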