I would like to use `torch.autograd.grad` to compute per-sample Jacobian matrices for a batch of inputs.

Say I have a function `y = f(x)` where `y` has the same dimension as `x`. If I have a batch of inputs `x`, denoted x₁, x₂, …, xₙ, where the batch size is `n`, I want to compute the quantity

```
∑ᵢ trace(∂f(xᵢ)/∂xᵢ)
```

Namely, for each d × d Jacobian matrix ∂f(xᵢ)/∂xᵢ (where d is the dimension of xᵢ), I want to compute the trace of this matrix and then sum over the data index `i`.
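
For a single sample xᵢ this is just the trace of an ordinary d × d matrix; as a sanity check, it can be computed directly, for example with `torch.autograd.functional.jacobian`:

```python
import torch

f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
xi = torch.tensor([1., 2.])                        # a single input with d = 2
jac_i = torch.autograd.functional.jacobian(f, xi)  # the d x d Jacobian ∂f(xᵢ)/∂xᵢ
print(jac_i.trace())                               # trace(∂f(xᵢ)/∂xᵢ)
```

My own (loop-based) attempt with `torch.autograd.grad` looks like this: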

```python
import torch

def foo(f: torch.nn.Module, xi: torch.Tensor):
    """
    Compute trace(∂f(xᵢ)/∂xᵢ).
    Here xi is a single input, not a batch of data.
    """
    x_dim = xi.numel()
    xi.requires_grad_(True)
    y = f(xi)
    trace = 0.0
    for j in range(x_dim):
        # one autograd.grad call per output component: row j of the Jacobian,
        # of which we keep only the diagonal entry ∂f(xᵢ)[j]/∂xᵢ[j]
        grad_j = torch.autograd.grad(y[j], xi, retain_graph=True)[0]
        trace = trace + grad_j[j]
    return trace

if __name__ == "__main__":
    x_batch = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
    trace_sum = 0
    for i in range(x_batch.shape[0]):
        xi = x_batch[i].clone()
        trace_sum += foo(f, xi)
    print(trace_sum)
```

Is there a way to avoid looping over the batch dimension with the for loop? Thanks a lot for your kind help in advance!

P.S. I also figured out that the desired quantity ∑ᵢ trace(∂f(xᵢ)/∂xᵢ) is just

```
∑ᵢ ∑ⱼ ∂f(xᵢ)[j]/∂xᵢ[j]
```

So I should be able to do something like this:

```python
import torch

def foo_batch(f: torch.nn.Module, x):
    """
    Compute ∑ᵢ trace(∂f(xᵢ)/∂xᵢ).
    x is a batch of inputs.
    """
    x.requires_grad_(True)
    y = f(x)
    y_flat = y.reshape((-1,))
    x_flat = x.reshape((-1,))
    # This is memory inefficient as it requires a huge identity matrix as grad_outputs.
    eye = torch.eye(y_flat.numel())
    # this call raises the RuntimeError quoted below
    jac = torch.autograd.grad(y_flat, x_flat, grad_outputs=eye, is_grads_batched=True)[0]
    return jac.trace()

if __name__ == "__main__":
    x_batch = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
    print(foo_batch(f, x_batch))
```

But when I run this code, I get the runtime error

```
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior
```

I am not sure what this error means; any help would be appreciated.

Hi Hongkai!

`y_flat` is a function of `x` (as is `y`). However, `y_flat` is not a function of
`x_flat` in the sense of the chain of computations recorded by pytorch's
computation graph. (Because `x_flat` is an invertible function of `x`, we
can understand `y_flat` to be implicitly a function of `x_flat`, but autograd
only knows about the computations it actually recorded.) This is why you get
the "appears to not have been used in the graph" error: because `y_flat` is not
(according to the computation graph) a function of `x_flat`.
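
Here is a small sketch of the same issue in isolation (not your exact code): asking for gradients with respect to a freshly reshaped tensor that the output was not computed from reproduces the error:

```python
import torch

f = torch.nn.Linear(2, 2)
x = torch.randn(3, 2, requires_grad=True)
y = f(x)                    # the graph records x -> y
x_flat = x.reshape((-1,))   # a separate branch x -> x_flat; y was not computed from x_flat
# raises: "One of the differentiated Tensors appears to not have been used in the graph."
torch.autograd.grad(y.sum(), x_flat)
```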

You may tweak `foo_batch()` so that `y_flat` is a function of `x_flat` (and so
that `x_flat` becomes a leaf node of the computation graph).

As an aside, it might be faster or more memory efficient to use `jacrev()` or
`jacobian()` to compute your jacobian matrix.

Consider:

```python
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> def foo_batchB (f: torch.nn.Module, x):
...     x_flat = x.reshape((-1,))
...     x_flat.requires_grad = True   # x_flat will now be the leaf of the computation graph
...     y = f (x_flat.reshape (x.shape))
...     y_flat = y.reshape((-1,))
...     # This is memory inefficient as it requires a huge identity matrix as grad_outputs.
...     eye = torch.eye (y_flat.numel())
...     jac = torch.autograd.grad (y_flat, x_flat, grad_outputs = eye, is_grads_batched = True)[0]
...     return jac.trace()
...
>>> f = torch.nn.Sequential (torch.nn.Linear (2, 3), torch.nn.ReLU(), torch.nn.Linear (3, 2))
>>> x_batch = torch.tensor ([[1, 2.], [3, 4.], [5, 6.]])
>>>
>>> foo_batchB (f, x_batch)
>>>
>>> torch.func.jacrev (f) (x_batch).reshape (x_batch.numel(), x_batch.numel()).trace()
>>>
>>> torch.autograd.functional.jacobian (f, x_batch, vectorize = True).reshape (x_batch.numel(), x_batch.numel()).trace()
tensor(-0.0179)
```

[Edit: Note that using `foo_batchB()`, `jacrev()`, or `jacobian()` with `x_batch`
will compute the “block-off-diagonal” elements of the full jacobian matrix (even
though they are all zero). Your original for-loop version avoids this inefficiency
(although it does compute the full per-batch-element jacobian matrices whose
individual off-diagonal elements get ignored when you compute `.trace()`). So
your for-loop version could well be the most efficient, especially as the batch
size becomes larger. It’s conceivable that you could “vectorize” your for-loop
version using vmap(), but I’ve never used `vmap()`.]
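
As a further aside, one possible way to avoid both kinds of wasted work, exploiting the identity ∑ᵢ ∑ⱼ ∂f(xᵢ)[j]/∂xᵢ[j] from your post, is to loop over the output dimension d instead of the batch dimension, with one backward pass over the whole batch per output component. A sketch (assuming, as here, that f acts on each batch element independently):

```python
import torch

def trace_sum_by_output_dim(f: torch.nn.Module, x_batch: torch.Tensor):
    x = x_batch.clone().requires_grad_(True)
    y = f(x)                                      # shape (n, d)
    total = 0.0
    for j in range(y.shape[1]):
        # d(sum_i f(x_i)[j]) / dx has shape (n, d); keep only its j-th column
        grad_j = torch.autograd.grad(y[:, j].sum(), x, retain_graph=True)[0]
        total = total + grad_j[:, j].sum()
    return total

f = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
x_batch = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
print(trace_sum_by_output_dim(f, x_batch))
```

This never forms a full jacobian matrix; it only ever materializes the (n, d) gradient of one output component at a time.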

Best.

K. Frank


Thanks a lot @KFrank for your kind help! The explanation is super clear, much appreciated!

In case it helps future users, when using `vmap` the code looks something like this (here using `torch.func.jacrev` to get the per-sample Jacobian):

```python
def foo(f: torch.nn.Module, xi: torch.Tensor):
    x_dim = xi.numel()
    # per-sample Jacobian, then its trace
    return torch.func.jacrev(f)(xi).reshape(x_dim, x_dim).diagonal().sum()

# f and x_batch as defined above; vmap over the batch, then sum the per-sample traces
trace_sum = torch.func.vmap(lambda xi: foo(f, xi))(x_batch).sum()
```