Hello.
Given a function $ R^{3} \mapsto R^{3} $ implemented as point-wise MLPs, what is the most efficient way of computing the Jacobian matrix?
In my case, my input and output are of shape (B, N, 3)
and the desired Jacobian matrix would be of shape (B, N, 3, 3)
as the feedforward function is applied point-wise. I first view
as (B*N, 3)
and call torch.autograd.functional.jacobian
that results in output of shape (B*N, 3, 3, B*N, 3)
. If I understand correctly, the gradients are computed between every point in my input and every point in the output. Rightly so, jacobian.sum(dim=2)
gives me the desired (B*N, 3,3)
as the output. Instead of computing a backward pass between different points (which are 0s), could the jacobian computation be done more efficiently?
I also have come up with an implementation that is faster than torch.autograd.functional.jacobian
as shown below. However, I’m not sure whether the slicing
and cat
operation applied here breaks the gradient. Practically, I observe that my loss doesn’t change at all and hence was curious what could be going wrong.
def compute_jacobian(inp, out):
u = out[:, 0]
v = out[:,1]
w = out[:,2]
grad_outputs = torch.ones_like(u)
grad_u = torch.autograd.grad(u, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
grad_v = torch.autograd.grad(v, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
grad_w = torch.autograd.grad(w, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
return torch.cat((grad_u, grad_v, grad_w), dim=-1).reshape(-1, 3,3)
def compute_jacobian_batch(inp, out):
u = out[:,:,0]
v = out[:,:,1]
w = out[:,:,2]
n_batch = inp.size(0)
grad_outputs = torch.ones_like(u)
grad_u = torch.autograd.grad(u, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
grad_v = torch.autograd.grad(v, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
grad_w = torch.autograd.grad(w, [inp], grad_outputs=grad_outputs, create_graph=True)[0]
jacob_mat = torch.cat((grad_u, grad_v, grad_w), dim=-1).reshape(n_batch, -1, 3,3).contiguous()
jacob_mat.retain_grad()
return jacob_mat
Can someone please verify this implementation ? Thank you!