More efficient per-example gradient norm computation using vmap

Hi, I currently compute per-example gradient norms, i.e. the norm of the gradient of the loss w.r.t. the model parameters for each example separately, in the following way:

import torch

vmap_loss = torch.vmap(compute_loss_for_single_instance, in_dims=(None, None, 0, 0))
losses = vmap_loss(network, loss_fn, X, y)
norm_gradients = [compute_grad_norm(torch.autograd.grad(loss, network.parameters(), retain_graph=True)).cpu().numpy()
                  for loss in losses]

where I define the following helper functions:

def compute_loss_for_single_instance(network, loss_function, image, label):
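    # vmap passes in a single example without a batch dimension,
    # so add it back before the forward pass.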
    y_pred = network(image.unsqueeze(0))
    loss = loss_function(y_pred, label.unsqueeze(0))
    return loss

def compute_grad_norm(grads):
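    # Concatenate all parameter gradients into one flat vector and take its L2 norm.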
    grads = [param_grad.detach().flatten() for param_grad in grads if param_grad is not None]
    norm = torch.cat(grads).norm()
    return norm

It works, but it is slow: the list comprehension calls torch.autograd.grad once per example, so the backward passes are never batched.
Is there a way to compute norm_gradients more efficiently?
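
One idea I have been looking at but have not verified: if I read the torch.func docs correctly, combining functional_call, grad, and vmap should compute all per-sample gradients in a single vectorized call, removing the Python loop entirely. A rough, untested sketch of what I mean, reusing network, loss_fn, X, and y from above:

import torch
from torch.func import functional_call, grad, vmap

def loss_for_params(params, buffers, image, label):
    # Stateless loss: parameters are explicit inputs, so grad() can
    # differentiate w.r.t. them and vmap() can batch over examples.
    y_pred = functional_call(network, (params, buffers), (image.unsqueeze(0),))
    return loss_fn(y_pred, label.unsqueeze(0))

params = {name: p.detach() for name, p in network.named_parameters()}
buffers = {name: b.detach() for name, b in network.named_buffers()}

# A single call computes the per-sample gradients for the whole batch;
# each dict entry has shape (batch_size, *param_shape).
per_sample_grads = vmap(grad(loss_for_params), in_dims=(None, None, 0, 0))(params, buffers, X, y)

# Flatten each parameter's gradients per sample and take row-wise L2 norms.
norm_gradients = torch.cat(
    [g.flatten(start_dim=1) for g in per_sample_grads.values()], dim=1
).norm(dim=1)

Would something like this be the right approach, or is there an even faster way?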

Many thanks