torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:

Hi All,

I’ve been trying to use jit to speed a calculation I’m using, however, it doesn’t seem to be supported. I’m currently using torch.functional.hessian to calculate the Laplacian (the trace of the hessian) of my network and I thought jit might be able to speed the evaluation. I was wondering if someone might be able to confirm that you can’t jit a hessian function or perhaps I coded in the solution wrong?

Here’s some example code I’m using, to reproduce the error

import torch
import torch.nn as nn
from torch.autograd.functional import hessian
from torch import Tensor

class Model(nn.Module):

  def __init__(self, num_inputs, num_hidden, num_outputs):
    super(Model, self).__init__()
    
    self.num_inputs = num_inputs
    self.num_hidden = num_hidden
    self.num_outputs = num_outputs
    
    self.fc1 = nn.Linear(num_inputs, num_hidden, bias=True)
    self.fc2 = nn.Linear(num_hidden, num_outputs, bias=True)
    
    self.af1 = nn.Tanh()
    
  def forward(self, x):
    x = self.fc1(x)
    x = self.af1(x)
    x = self.fc2(x)
    return x
 
@torch.jit.script
def calc_laplacian(x: Tensor) -> Tensor:
  laplacian = torch.zeros(x.shape[0])
  for i, xi in enumerate(x):
    hess = torch.autograd.functional.hessian(net,xi.unsqueeze(0), create_graph=False) #calc. hessian
    laplacian[i] = torch.diagonal(hess.view(net.num_inputs, net.num_inputs), offset=0).sum() #get trace

  return laplacian
  
net = Model(num_inputs=3, num_hidden=16, num_outputs=1)

x = torch.randn(1000,3) #get some random values to calculate
y = net(x)                       #output of network 

trace_d2y_dx2 = calc_laplacian(x) #calculate laplacian

The error I get is the following,


Traceback (most recent call last):
  File "jit_hessian.py", line 27, in <module>
    def calc_laplacian(x: Tensor) -> Tensor:
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1550, in script
    fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 583, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1547, in script
    ast = get_jit_def(obj, obj.__name__)
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 185, in get_jit_def
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 219, in build_def
    build_stmts(ctx, body))
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 126, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 126, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 192, in __call__
    raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
  File "/home/user/anaconda3/lib/python3.7/site-packages/torch/autograd/functional.py", line 538
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")

    def ensure_single_output_function(*inp):
    ~~~ <--- HERE
        out = func(*inp)
        is_out_tuple, t_out = _as_tuple(out, "outputs of the user-provided function", "hessian")

I did notice from a similar question here that the listcomp call may be an issue? Could this be the problem?

Many thanks for the help in advance!

1 Like