Hi, I currently compute per-example gradient norms, i.e. the norm of the gradient of the loss w.r.t. the model parameters, evaluated separately for each example, in the following way:

```
vmap_loss = torch.vmap(compute_loss_for_single_instance, in_dims=(None, None, 0, 0))
losses = vmap_loss(network, loss_fn, X, y)
norm_gradients = [
    compute_grad_norm(torch.autograd.grad(loss, network.parameters(), retain_graph=True)).cpu().numpy()
    for loss in losses
]
```

where I define the auxiliary methods:

```
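import torch


def compute_loss_for_single_instance(network, loss_function, image, label):
    # Add a batch dimension of size 1 so the network sees a single-example batch.
    y_pred = network(image.unsqueeze(0))
    loss = loss_function(y_pred, label.unsqueeze(0))
    return loss


def compute_grad_norm(grads):
    # Flatten every parameter gradient and take the L2 norm of the concatenation.
    grads = [param_grad.detach().flatten() for param_grad in grads if param_grad is not None]
    norm = torch.cat(grads).norm()
    return norm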
```

It works, but it is slow, since it runs a separate backward pass for every example in the batch.

Is there a way to compute norm_gradients more efficiently?
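One direction that might apply here is the per-sample-gradient pattern from `torch.func`, which combines `functional_call`, `grad` and `vmap` so that all per-example gradients come out of a single vectorized call instead of a Python loop. The following is only a sketch under assumptions: the toy `network`, `loss_fn`, `X`, `y` and the helper `loss_for_single_instance` are hypothetical stand-ins, and it assumes the model has no buffers (for example no batch norm layers) and that every operation in it is supported by `vmap`. Whether it is actually faster will depend on the model and batch size.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

# Hypothetical toy setup, only for illustration.
torch.manual_seed(0)
network = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 3))
loss_fn = nn.CrossEntropyLoss()
X = torch.randn(5, 8)
y = torch.randint(0, 3, (5,))

# Detached parameters keyed by name, as expected by functional_call.
params = {name: p.detach() for name, p in network.named_parameters()}


def loss_for_single_instance(params, image, label):
    # Same computation as compute_loss_for_single_instance, but with the
    # parameters passed in explicitly so that grad can differentiate them.
    y_pred = functional_call(network, params, (image.unsqueeze(0),))
    return loss_fn(y_pred, label.unsqueeze(0))


# grad differentiates w.r.t. the first argument (params);
# vmap maps over dim 0 of image and label and shares params across examples.
per_sample_grads = vmap(grad(loss_for_single_instance), in_dims=(None, 0, 0))(params, X, y)

# Each gradient has shape (batch, *param.shape): sum squares per example
# over all parameters, then take the square root.
squared = [g.flatten(start_dim=1).pow(2).sum(dim=1) for g in per_sample_grads.values()]
norm_gradients = torch.stack(squared).sum(dim=0).sqrt()
```

Here `norm_gradients` is a single tensor of shape `(batch,)`, so `norm_gradients.cpu().numpy()` would give one array rather than a list of scalars. Note that this materializes the full gradient of every example at once, so memory use grows with batch size times parameter count.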

Many thanks