Per-sample gradients w.r.t. the output of a layer

Hi @fy-meng,

Have a look at the torch.func namespace: you can compute per-sample gradients efficiently by wrapping torch.func.grad in torch.func.vmap.
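Here is a minimal sketch of that pattern applied to the output of a layer. The model, its split into `body` and `head`, the shapes, and the cross-entropy loss are all hypothetical and only serve as an illustration. The idea is to compute the layer output `h` once, write the loss for a *single* sample as a function of `h_i`, and let `vmap(grad(...))` map it over the batch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import grad, vmap

torch.manual_seed(0)

# Hypothetical toy model, split at the layer whose output we care about.
body = nn.Sequential(nn.Linear(4, 8), nn.ReLU())  # produces the layer output h
head = nn.Linear(8, 3)                            # maps h to class logits

x = torch.randn(16, 4)          # batch of 16 samples
y = torch.randint(0, 3, (16,))  # integer class targets

# Output of the layer of interest, shape (16, 8). Detached because we only
# differentiate w.r.t. h itself, not the parameters of `body`.
h = body(x).detach()

def sample_loss(h_i, y_i):
    # h_i: (8,), y_i: () for one sample; add a batch dim of size 1
    logits = head(h_i.unsqueeze(0))
    return F.cross_entropy(logits, y_i.unsqueeze(0))

# grad() differentiates w.r.t. the first argument (h_i);
# vmap() maps that over dim 0 of both h and y.
per_sample_grads = vmap(grad(sample_loss))(h, y)
print(per_sample_grads.shape)  # torch.Size([16, 8])
```

Because `h_i` only affects the loss of sample `i`, you would get the same numbers from one backward pass of the summed loss (`reduction="sum"`) w.r.t. `h`, as long as nothing in `head` mixes samples (e.g. BatchNorm in training mode). The `vmap(grad(...))` pattern is also what you would use for per-sample gradients w.r.t. parameters, where that shortcut does not apply.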

You can find an example I wrote for calculating the per-sample Jacobian over a batch of data here.
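This is not the example referenced above, just an independent minimal sketch of the same idea, assuming a small hypothetical network: `jacrev` gives the Jacobian of the output w.r.t. the input for one sample, and `vmap` batches it.

```python
import torch
import torch.nn as nn
from torch.func import jacrev, vmap

torch.manual_seed(0)

# Hypothetical small network; any function of a single sample works.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
x = torch.randn(16, 4)

def f(x_i):
    # x_i: (4,) -> output: (3,)
    return model(x_i)

# jacrev(f) gives the (3, 4) Jacobian for one sample; vmap batches it.
per_sample_jac = vmap(jacrev(f))(x)
print(per_sample_jac.shape)  # torch.Size([16, 3, 4])
```

Swapping `x` for a layer output (as in the gradient case) gives the per-sample Jacobian w.r.t. that layer instead.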