Hi,
In my use case, I need to compute the diagonal of the Hessian for 3D tensors.
The code below illustrates what I need to compute. It works; however, it does not scale well as the dimensions of the tensor increase.
Is there a faster way to achieve the same computation?
I am on torch==1.10.0.
from time import time
import torch
from torch._vmap_internals import _vmap as vmap
def grad(outputs, inputs, **kwargs):
    """
    Compute the element-wise derivative of a 3D tensor w.r.t. per-sample inputs.

    This is the *diagonal* of the batch Jacobian (not its trace): one
    derivative is returned for every entry of ``outputs``::

        result[i, j, k] = d outputs[i, j, k] / d inputs[i]

    The batch axis of ``outputs`` is summed before differentiating, so the
    result is only meaningful when sample ``i`` of ``outputs`` depends on
    sample ``i`` of ``inputs`` alone (no cross-sample coupling).

    See: https://discuss.pytorch.org/t/jacobian-functional-api-batch-respecting-jacobian/84571/7?u=amerlo94

    Args:
        outputs: 3D tensor of shape ``(bs, d1, d2)`` computed from ``inputs``.
        inputs: tensor holding ``bs`` elements in total, one scalar per
            sample (e.g. shape ``(bs, 1, 1)``).
        **kwargs: forwarded unchanged to ``torch.autograd.grad`` (for example
            ``create_graph=True`` when the result is differentiated again).

    Returns:
        A tensor with the same shape as ``outputs``.
    """
    shape = outputs.shape
    bs = shape[0]
    n = shape[1] * shape[2]

    # reshape (rather than view) also accepts non-contiguous outputs, e.g. the
    # result of a previous call to this function.
    summed = outputs.reshape(bs, n).sum(dim=0)

    # Row r of the identity selects output entry r in the vector-Jacobian
    # product, so the n rows together recover the full set of derivatives.
    basis = torch.eye(n, dtype=summed.dtype, device=summed.device)

    def get_vjp(v):
        return torch.autograd.grad(summed, inputs, v, **kwargs)[0]

    # Shape (n, *inputs.shape): one gradient w.r.t. inputs per output entry.
    per_entry = vmap(get_vjp)(basis)

    # Flatten explicitly to (n, bs) and transpose the resulting matrix; using
    # `.T` on a tensor with more than two dimensions is deprecated.
    return per_entry.reshape(n, bs).t().reshape(shape)
# Problem size: each of the `bs` samples yields an (n, n) output, and the
# timing loop below is repeated `iterations` times.
n = 16
bs = 128
iterations = 1000

# Fixed random coefficients, broadcast over the batch axis inside f.
b = torch.randn(1, n, n)


def f(x):
    """Element-wise test function; the checks below expect df/dx = 2 * x * b."""
    return x.pow(2) * b


# One scalar input per sample, shaped (bs, 1, 1) so it broadcasts against b.
x = torch.rand(bs).requires_grad_(True).view(-1, 1, 1)
y = f(x)

# Check gradients
# First derivative keeps its graph so that it can be differentiated again.
jac = grad(y, x, create_graph=True)
hes = grad(jac, x)
assert torch.allclose(jac, 2 * x * b)
assert torch.allclose(hes, 2 * b)

# Time gradients
start = time()
for _ in range(iterations):
    jac = grad(y, x, create_graph=True)
    grad(jac, x)
elapsed = time() - start

# NOTE(review): the total is divided by both the iteration count and bs, so
# despite its name this is the average time per sample, not per batch.
time_per_batch = elapsed / iterations / bs
print(f"Time per batch: {time_per_batch * 1e6:.2f} us")