Assume we have a batch of data. For each data point in the batch, I would like to get the norm of the gradient of the network's output w.r.t. the network parameters.
A naive solution is to perform multiple forward and backward passes (one pass per data point), but this is pretty slow. Is there a better solution?
A better solution seems possible, but I don’t see any mention of it in the PyTorch documentation.
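
For concreteness, the naive approach I have in mind looks roughly like the sketch below. The model, the batch shape, and the helper name `per_sample_grad_norms` are placeholders made up for illustration, and I am assuming the network produces a single scalar per data point.

```python
import torch
import torch.nn as nn


def per_sample_grad_norms(model, inputs):
    """Naive baseline: one forward and one backward pass per data point.

    Hypothetical helper; assumes the model output is a scalar per example.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    norms = []
    for x in inputs:
        # Forward pass on a single example (keep a batch dimension of 1).
        out = model(x.unsqueeze(0)).sum()
        # Gradient of that scalar output w.r.t. every parameter.
        grads = torch.autograd.grad(out, params)
        # L2 norm over all parameters, treated as one flat vector.
        norms.append(torch.sqrt(sum(g.pow(2).sum() for g in grads)))
    return torch.stack(norms)


torch.manual_seed(0)
# Placeholder model and batch, only here to make the sketch runnable.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
batch = torch.randn(16, 4)
norms = per_sample_grad_norms(model, batch)
print(norms.shape)  # one gradient norm per data point
```

The loop costs one forward and one backward pass per example, which is exactly the part I would like to avoid.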