RuntimeError: The Tensor returned by the function given to hessian should contain a single element - Is it possible for torch.autograd.functional.hessian to support batches?

Hi All,

I was wondering whether it's possible to compute the Hessian of an nn.Module's output with respect to its input for a batch of inputs. I understand that torch.autograd.functional.hessian requires the function to return a single element, but is there a way to add batch support to it, like I've done below with the jacobian function?
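When the examples in a batch are independent, the batched Hessian is simply the per-example Hessian stacked along the batch dimension. A minimal sketch, assuming PyTorch 2.0 or later (where `torch.func` is available) and a hypothetical function `func_single` written for one example rather than a batch:

```python
import torch
from torch.func import hessian, vmap  # torch.func requires PyTorch 2.0+

def func_single(x):
    # Hypothetical per-example version of func: x has shape (D,), output is a scalar
    return x.pow(2).sum(dim=-1)

xs = torch.randn(4, 3)  # batch of B=4 examples with D=3 features each

# hessian(func_single) maps one (D,) input to a (D, D) Hessian;
# vmap applies it independently to every row of xs
batched_hessian = vmap(hessian(func_single))(xs)
print(batched_hessian.shape)  # torch.Size([4, 3, 3])
```

Because `vmap` hands the function one example at a time, this only matches the batched model if no layer mixes information across examples (e.g. BatchNorm in training mode).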

Any help would be appreciated!
Thank you! :slight_smile:

```python
import torch

def func(xs):
  return xs.pow(2).sum(dim=-1)  # one scalar per example (sum over the last dim)

def jacobian(net, xs):
  # Works for a batch because ys[b] depends only on xs[b]
  ys = net(xs)
  return torch.autograd.grad(ys, xs, torch.ones_like(ys))[0]

def hessian(net, xs):
  return torch.autograd.functional.hessian(net, xs)

xs = torch.randn(1, requires_grad=True)  # single example (i.e. no batch)
jacobian(func, xs)  # returns tensor([0.6949]), i.e. 2 * xs (values vary, xs is random)
hessian(func, xs)   # returns tensor([[2.]])

xs = torch.randn(2, 1, requires_grad=True)  # batch dim included

jacobian(func, xs)
# returns
# tensor([[-2.7621],
#         [ 0.1642]])

hessian(func, xs)  # throws the error below: func(xs) now has 2 elements
```
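The `jacobian` helper above works for a batch because `ys[b]` depends only on `xs[b]`, so back-propagating `torch.ones_like(ys)` yields every example's gradient in a single call. The same trick can be applied a second time to get second derivatives, one input feature at a time. A sketch under that independence assumption; `batched_hessian` is a hypothetical helper name, and it uses only `torch.autograd.grad`:

```python
import torch

def func(xs):
    # Same toy function as in the post: one scalar output per example
    return xs.pow(2).sum(dim=-1)

def batched_hessian(net, xs):
    """Per-example Hessians of net(xs) w.r.t. xs, returned with shape (B, D, D).

    Assumes ys[b] depends only on xs[b] (no cross-example coupling).
    """
    xs = xs.detach().requires_grad_(True)
    ys = net(xs)  # shape (B,)
    # First derivatives, kept in the graph so they can be differentiated again
    grads = torch.autograd.grad(ys, xs, torch.ones_like(ys), create_graph=True)[0]  # (B, D)
    rows = []
    for i in range(xs.shape[1]):
        # Derivative of grads[:, i] w.r.t. xs; independence keeps it per-example
        row = torch.autograd.grad(grads[:, i], xs, torch.ones_like(grads[:, i]),
                                  retain_graph=True)[0]  # (B, D)
        rows.append(row)
    # Entry [b, i, j] is d^2 ys[b] / (dxs[b, i] dxs[b, j])
    return torch.stack(rows, dim=1)

xs = torch.randn(2, 3)
H = batched_hessian(func, xs)
print(H.shape)  # torch.Size([2, 3, 3])
```

This runs one extra backward pass per input feature (D of them), each covering the whole batch at once.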

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 2, in hessian
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/functional.py", line 701, in hessian
    res = jacobian(jac_func, inputs, create_graph=create_graph, strict=strict, vectorize=vectorize)
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/functional.py", line 482, in jacobian
    outputs = func(*inputs)
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/functional.py", line 697, in jac_func
    jac = jacobian(ensure_single_output_function, inp, create_graph=True)
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/functional.py", line 482, in jacobian
    outputs = func(*inputs)
  File "~/anaconda3/lib/python3.8/site-packages/torch/autograd/functional.py", line 692, in ensure_single_output_function
    raise RuntimeError("The Tensor returned by the function given to hessian should contain a single element")
RuntimeError: The Tensor returned by the function given to hessian should contain a single element
```
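The error is raised because `hessian` requires the function to return a single element, while `func(xs)` returns one value per example. One workaround that keeps `torch.autograd.functional.hessian` is to differentiate the sum of the outputs instead. When the examples are independent, the resulting `(B, D, B, D)` Hessian is zero outside the diagonal blocks `[b, :, b, :]`, and those blocks are exactly the per-example Hessians. A sketch; `batched_hessian_via_sum` is a hypothetical name:

```python
import torch
from torch.autograd.functional import hessian

def func(xs):
    return xs.pow(2).sum(dim=-1)

def batched_hessian_via_sum(net, xs):
    # Summing over the batch gives a single-element output, as hessian requires
    full = hessian(lambda x: net(x).sum(), xs)  # shape (B, D, B, D)
    # Take the blocks full[b, :, b, :]; torch.diagonal over dims 0 and 2
    # puts the batch index last, giving (D, D, B), so move it back to the front
    return torch.diagonal(full, dim1=0, dim2=2).permute(2, 0, 1)  # (B, D, D)

xs = torch.randn(2, 3)
H = batched_hessian_via_sum(func, xs)
print(H.shape)  # torch.Size([2, 3, 3])
```

This materializes the full (B·D) × (B·D) Hessian, including the all-zero off-diagonal blocks, so memory and compute grow quadratically with the batch size.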