I’d like to compute per-sample gradients given per-sample loss values and a batch of inputs. For example:
import torch
from torch import autograd

batch_size = 7
num_features = 12
x = torch.randn((batch_size, num_features)).requires_grad_()
y = x * 3.0
label = torch.randn((batch_size, num_features))
loss = torch.mean((label - y).pow(2), dim=1)
grad1 = ps_grad(loss, x, create_graph=False)
So I’d like to implement the ps_grad function above. A naive implementation would be:
def ps_grad(outputs: torch.Tensor, inputs: torch.Tensor, create_graph: bool = False):
    batch_size = inputs.size(0)
    grads = [autograd.grad(outputs[[i]], inputs[[i]], create_graph=create_graph)[0]
             for i in range(batch_size)]
    return torch.cat(grads, dim=0)
However, this errors out with:
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
This is presumably because inputs[[i]] is a new tensor created by indexing, and that tensor never participated in the computation of the loss — the loss was computed from the full x.
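One way to make the loop run (a sketch, not optimized — the helper name ps_grad_loop is mine) is to differentiate each per-sample loss with respect to the full input tensor and keep only the matching row; the other rows come out as zeros:

```python
import torch
from torch import autograd

def ps_grad_loop(outputs: torch.Tensor, inputs: torch.Tensor,
                 create_graph: bool = False) -> torch.Tensor:
    # Differentiate each scalar outputs[i] w.r.t. the FULL input tensor;
    # rows belonging to other samples are zero, so keep row i only.
    # retain_graph=True is needed because we backprop through the same
    # graph once per sample.
    rows = [autograd.grad(outputs[i], inputs, retain_graph=True,
                          create_graph=create_graph)[0][i]
            for i in range(inputs.size(0))]
    return torch.stack(rows, dim=0)

batch_size, num_features = 7, 12
x = torch.randn(batch_size, num_features).requires_grad_()
label = torch.randn(batch_size, num_features)
loss = torch.mean((label - x * 3.0).pow(2), dim=1)
per_sample = ps_grad_loop(loss, x)  # shape (batch_size, num_features)
```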
The example here does the equivalent of calculating the loss inside the ps_grad function, but that’s not an option.
I’ve also tried using vmap, but the problem is that it creates a BatchedTensor internally, which loses the requires_grad state of the inputs.
I’m not too concerned with performance at this point. Any ideas?
To answer my own question: For most/all practical purposes, this is a valid implementation of ps_grad:
def ps_grad(outputs: torch.Tensor, inputs: torch.Tensor, create_graph: bool = False):
    return autograd.grad(outputs.sum(), inputs, create_graph=create_graph)[0]
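This works because each loss[i] depends only on inputs[i], so row i of the gradient of the summed loss equals the gradient of loss[i] alone (the trick breaks down if samples interact, e.g. through batch norm). A quick sanity check against the analytic gradient of the quadratic loss from the question:

```python
import torch
from torch import autograd

def ps_grad(outputs: torch.Tensor, inputs: torch.Tensor, create_graph: bool = False):
    return autograd.grad(outputs.sum(), inputs, create_graph=create_graph)[0]

batch_size, num_features = 7, 12
x = torch.randn(batch_size, num_features).requires_grad_()
label = torch.randn(batch_size, num_features)
loss = torch.mean((label - x * 3.0).pow(2), dim=1)

g = ps_grad(loss, x)
# Analytic gradient of mean((label - 3x)**2, dim=1) w.r.t. x:
# d/dx [ (label - 3x)^2 / F ] = 6 * (3x - label) / F
expected = 6.0 * (3.0 * x.detach() - label) / num_features
```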
Hi @jose-solorzano,
So why exactly doesn't torch.func work for your example? Could you share what you did with torch.func?
Also, you ideally shouldn't mix torch.autograd.grad and torch.func operations, as they compute gradients in different ways. The .grad attribute accumulates gradients, so they have the same shape as the parameter tensor they're attached to. torch.func, by contrast, creates a functionalized form of your loss and computes per-sample gradients efficiently by removing the batch dimension: in short, it defines the function for one sample and then vectorizes over all samples. This effectively ignores the .grad attribute entirely, as torch.func only sees PyTorch primitives. You could potentially assign the per-sample gradients to the .grad field of your parameter, but you'd get a shape mismatch error.
I mean something like this:
def ps_grad(outputs: torch.Tensor, inputs: torch.Tensor, create_graph: bool = False):
    def _singleton_grad(_out: torch.Tensor, _in: torch.Tensor):
        return autograd.grad(_out, _in, create_graph=create_graph)[0]
    return vmap(_singleton_grad)(outputs, inputs)
Errors out with:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I’m curious if there’s a right way to do that (though it doesn’t matter now, because I have a satisfactory solution for my use case.)
The example you show there wouldn't work because you're mixing torch.autograd.grad with torch.func.vmap. The torch.autograd and torch.func namespaces compute gradients in different ways, which leads to the "does not require grad" error you quoted above.
What you could do is replace torch.autograd.grad with torch.func.grad to compute the per-sample gradients; that would be compatible with the torch.func namespace.
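A sketch of what that suggestion looks like, assuming PyTorch 2.x where torch.func is available (the function name loss_fn is mine; it mirrors the quadratic loss from the question). The loss is defined for a single sample, wrapped in torch.func.grad, and vectorized over the batch with vmap:

```python
import torch
from torch.func import grad, vmap

# Per-sample loss for one row, matching the question's example.
def loss_fn(x_row: torch.Tensor, label_row: torch.Tensor) -> torch.Tensor:
    return torch.mean((label_row - x_row * 3.0).pow(2))

batch_size, num_features = 7, 12
x = torch.randn(batch_size, num_features)  # no requires_grad_ needed here
label = torch.randn(batch_size, num_features)

# grad() differentiates w.r.t. the first argument; vmap() maps over dim 0.
per_sample_grads = vmap(grad(loss_fn))(x, label)  # shape (batch_size, num_features)
```

Note that torch.func tracks gradients functionally, so x doesn't need requires_grad_() here.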