I’m trying to figure out how to compute the gradient for individual samples in a batched fashion. Specifically, given an input batch and per-sample score outputs (e.g., the MSE for each sample), I want to compute the gradient for each item in the batch. I can think of a simple way to do this: running forward/backward passes on individual items in a for loop. But that seems slow, since you’re looping in Python instead of efficiently running the forward pass on the whole batch. Is there a better way?
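For reference, here is a minimal sketch of the naive per-sample loop described above, assuming a toy linear model with a per-sample MSE score (the names `model`, `x`, and `y` are just illustrative):

```python
import torch

model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)   # batch of 8 input vectors
y = torch.randn(8, 1)   # targets

# Naive approach: one forward/backward pass per sample in a Python loop.
per_sample_grads = []
for xi, yi in zip(x, y):
    # Make a leaf tensor for this sample so we can take d(loss)/d(input).
    xi = xi.unsqueeze(0).clone().requires_grad_(True)
    loss = torch.nn.functional.mse_loss(model(xi), yi.unsqueeze(0))
    (gi,) = torch.autograd.grad(loss, xi)
    per_sample_grads.append(gi.squeeze(0))

grads = torch.stack(per_sample_grads)  # shape (8, 4): one gradient per sample
```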


Found my own answer! This function does everything I want:

https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad

@Giuseppe_Castiglione: could you please elaborate on your answer?

How did you use torch.autograd.grad() to get the gradient of each individual input vector from a batch?

A code snippet would be appreciated.

One relevant discussion is

Does this one really solve your problem?

I was able to use this function for the same task by summing the “output” tensor (the per-sample scores) and passing that sum into the function along with the input tensor as-is. The first tensor in the returned tuple then has the same batch dimension as the input, i.e. one gradient per sample.
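A minimal sketch of this summing trick, again assuming a toy linear model with per-sample MSE scores (`model`, `x`, `y` are illustrative, not from the thread):

```python
import torch

model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4, requires_grad=True)  # batch of 8 inputs
y = torch.randn(8, 1)                      # targets

# Per-sample MSE scores, shape (8,): one scalar score per batch item.
scores = ((model(x) - y) ** 2).mean(dim=1)

# Since sample j's score depends only on x[j], the gradient of the sum
# d(scores.sum())/dx[j] equals d(scores[j])/dx[j] for each sample, so one
# backward pass yields all per-sample input gradients at once.
(grads,) = torch.autograd.grad(scores.sum(), x)
print(grads.shape)  # torch.Size([8, 4])
```

Note that this gives per-sample gradients with respect to the *inputs*; the summing works because the samples don’t interact in the forward pass.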