Hi All,
I’ve been trying to use jit to speed a calculation I’m using, however, it doesn’t seem to be supported. I’m currently using torch.functional.hessian to calculate the Laplacian (the trace of the hessian) of my network and I thought jit might be able to speed the evaluation. I was wondering if someone might be able to confirm that you can’t jit a hessian function or perhaps I coded in the solution wrong?
Here’s some example code I’m using, to reproduce the error
import torch
import torch.nn as nn
from torch.autograd.functional import hessian
from torch import Tensor
class Model(nn.Module):
def __init__(self, num_inputs, num_hidden, num_outputs):
super(Model, self).__init__()
self.num_inputs = num_inputs
self.num_hidden = num_hidden
self.num_outputs = num_outputs
self.fc1 = nn.Linear(num_inputs, num_hidden, bias=True)
self.fc2 = nn.Linear(num_hidden, num_outputs, bias=True)
self.af1 = nn.Tanh()
def forward(self, x):
x = self.fc1(x)
x = self.af1(x)
x = self.fc2(x)
return x
@torch.jit.script
def calc_laplacian(x: Tensor) -> Tensor:
laplacian = torch.zeros(x.shape[0])
for i, xi in enumerate(x):
hess = torch.autograd.functional.hessian(net,xi.unsqueeze(0), create_graph=False) #calc. hessian
laplacian[i] = torch.diagonal(hess.view(net.num_inputs, net.num_inputs), offset=0).sum() #get trace
return laplacian
net = Model(num_inputs=3, num_hidden=16, num_outputs=1)
x = torch.randn(1000,3) #get some random values to calculate
y = net(x) #output of network
trace_d2y_dx2 = calc_laplacian(x) #calculate laplacian
The error I get is the following,
Traceback (most recent call last):
File "jit_hessian.py", line 27, in <module>
def calc_laplacian(x: Tensor) -> Tensor:
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1550, in script
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/_recursive.py", line 583, in try_compile_fn
return torch.jit.script(fn, _rcb=rcb)
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1547, in script
ast = get_jit_def(obj, obj.__name__)
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 185, in get_jit_def
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 219, in build_def
build_stmts(ctx, body))
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 126, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 126, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/jit/frontend.py", line 192, in __call__
raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
File "/home/user/anaconda3/lib/python3.7/site-packages/torch/autograd/functional.py", line 538
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
def ensure_single_output_function(*inp):
~~~ <--- HERE
out = func(*inp)
is_out_tuple, t_out = _as_tuple(out, "outputs of the user-provided function", "hessian")
I did notice from a similar question here that the listcomp call may be an issue? Could this be the problem?
Many thanks for the help in advance!