Per-sample gradients w.r.t. the output of a layer

Hi guys,

I’m trying to run per-sample TCAV, which requires computing the gradients w.r.t. the output of a layer (e.g. resnet18.layer4). I could set the batch size to 1, but that would make the script extremely inefficient and slow.

I’ve also looked at torch.nn.utils._per_sample_grad and autograd_hacks, but from my understanding they only store per-sample gradients for layers with weights, not for layer outputs, so they won’t work for resnet18.layer4.

Is there an efficient way to compute such gradients?

Hi @fy-meng,

Have a look at the torch.func namespace; you can efficiently compute per-sample gradients via torch.func.vmap.

You can find an example I wrote for calculating the per-sample Jacobian over a batch of data here.
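For your specific case, here’s a minimal sketch of the idea (not the linked example): split the network at layer4, then take the gradient of a scalar head output w.r.t. the activation and vmap it over the batch. The `features`/`head` helpers and the target class index 0 are just illustrations I made up for this sketch:

```python
import torch
from torch.func import grad, vmap
from torchvision.models import resnet18

model = resnet18().eval()

def features(x):
    # Forward pass up to and including layer4.
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    return model.layer4(x)

def head(act, target_idx):
    # Rest of the network, applied to a single activation (no batch dim);
    # returns the scalar logit for the target class.
    x = model.avgpool(act.unsqueeze(0))
    x = torch.flatten(x, 1)
    return model.fc(x)[0, target_idx]

x = torch.randn(8, 3, 224, 224)

with torch.no_grad():
    acts = features(x)  # shape (8, 512, 7, 7)

# grad(head) differentiates the scalar logit w.r.t. the layer4 activation;
# vmap maps it over the batch dimension of acts in one vectorized pass.
per_sample_grads = vmap(grad(head), in_dims=(0, None))(acts, 0)
print(per_sample_grads.shape)  # torch.Size([8, 512, 7, 7])
```

Note that torch.func ships with PyTorch 2.0+; on earlier versions the same transforms are available in the functorch package.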