I’m trying to figure out how to compute the gradient for individual samples in a batched fashion. Specifically, given an input batch and the per-sample score outputs (e.g. the MSE for each sample), I want to compute the gradient for each item in the batch. A simple way is to run forward/backward passes on individual items in a for loop, but that seems slow, since you’re looping in Python instead of efficiently doing the forward prop on the whole batch. Is there a better way to do this?
Found my own answer! This function does everything I want:
@Giuseppe_Castiglione: could you please elaborate on your answer?
How did you use torch.autograd.grad() to get the gradient of each individual input vector from a batch?
A code snippet would be appreciated.
One relevant discussion is
Does this one really solve your problem?
I was able to use this function for the same task by summing over the “output” tensor and passing the resulting scalar into the function along with the input tensor as-is. The first tensor in the result tuple has the same shape as the input tensor, with one gradient per sample along the batch dimension.
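For anyone who wants a snippet, here is a minimal sketch of that approach. The model, shapes, and variable names are made up for illustration, and it assumes the samples don't interact inside the model (for example, no batch norm in training mode). Note that it gives gradients with respect to the inputs; it does not give per-sample gradients with respect to the model parameters.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy model and data, for illustration only.
model = torch.nn.Linear(3, 1)
x = torch.randn(4, 3, requires_grad=True)  # batch of 4 input vectors
y = torch.randn(4, 1)

# Per-sample scores (squared error for each sample), shape (4,).
per_sample_loss = ((model(x) - y) ** 2).sum(dim=1)

# Summing gives a scalar. Since sample i's loss depends only on x[i],
# row i of the result is d(loss_i) / d(x[i]).
(grad_x,) = torch.autograd.grad(per_sample_loss.sum(), x)

# Sanity check against the slow per-sample loop from the question.
loop_grads = []
for i in range(x.shape[0]):
    xi = x[i:i + 1].detach().requires_grad_(True)
    loss_i = ((model(xi) - y[i:i + 1]) ** 2).sum()
    (gi,) = torch.autograd.grad(loss_i, xi)
    loop_grads.append(gi)
loop_grad_x = torch.cat(loop_grads, dim=0)

matches = torch.allclose(grad_x, loop_grad_x, atol=1e-6)
```

torch.autograd.grad() returns a tuple with one entry per input you pass, which is why the result is unpacked as `(grad_x,)`.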