Hi,
I am trying to generate TorchScript for torch.func (functorch) transforms (jacrev, hessian, vmap).
The goal is to save the TorchScript to a file and load it for inference via the C++ frontend (including CUDA).
I have several partially working solutions. However, they are not robust once the network uses operations beyond dense layers. Is there a straightforward/recommended way to do this?
Example using torch==2.0.0:
import torch
from torch.func import vmap, jacrev, hessian

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layer = torch.nn.Linear(2, 512)
        hidden_layers = []
        for i in range(2):
            hidden_layers.append(torch.nn.Linear(512, 512))
        self.hidden_layer = torch.nn.ModuleList(hidden_layers)
        self.out_layer = torch.nn.Linear(512, 1)

    def forward(self, x):
        x = self.input_layer(x)
        for layer in self.hidden_layer:
            x = torch.tanh(layer(x))
        x = self.out_layer(x)
        return x

model = Model()
# Per-sample Jacobian and Hessian of the output with respect to the input.
jac_func = vmap(jacrev(model))
hess_func = vmap(hessian(model))
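For reference, here is a minimal eager-mode sketch of what the two transforms return, so that any exported graph can be checked against it. It uses a hypothetical small stand-in network (`tiny`, width 16 instead of 512) purely to keep it fast, and `hessian` for the second-order case; the cross-check against `torch.autograd.functional.jacobian` is my own addition, not part of the setup above.

```python
import torch
from torch.func import vmap, jacrev, hessian

torch.manual_seed(0)

# Hypothetical small stand-in for Model (same structure: Linear -> tanh -> Linear).
tiny = torch.nn.Sequential(
    torch.nn.Linear(2, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)

x = torch.randn(4, 2)  # batch of 4 samples, 2 input features each

# vmap maps over the batch dimension; jacrev/hessian see a single sample of shape (2,).
jac = vmap(jacrev(tiny))(x)
hess = vmap(hessian(tiny))(x)

print(jac.shape)   # torch.Size([4, 1, 2])
print(hess.shape)  # torch.Size([4, 1, 2, 2])

# Cross-check the first sample against the classic autograd API.
ref = torch.autograd.functional.jacobian(tiny, x[0])
jac_matches = torch.allclose(jac[0], ref, atol=1e-5)
print(jac_matches)
```

The shapes are (batch, output, input) for the Jacobian and (batch, output, input, input) for the Hessian.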
The (in my mind) obvious thing to try,
torch.jit.script(jac_func)
fails with
[...]
File ".../python3.10/inspect.py", line 797, in getfile
raise TypeError('module, class, method, function, traceback, frame, or '
TypeError: module, class, method, function, traceback, frame, or code object was expected, got Model
and
dummy_inp = torch.zeros(1, 2)
torch.jit.trace(jac_func, dummy_inp)
fails with
[...]
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
0.4287 0.0843
0.6689 -0.2589
[...]
Now, what is working:
from torch.fx.experimental.proxy_tensor import make_fx
dummy_inp = torch.zeros(1, 2)
torch.jit.trace(make_fx(jac_func)(dummy_inp), dummy_inp)
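To make the full round trip concrete, here is a sketch of the make_fx + trace route including save, reload, and a comparison against the eager result. The small `tiny` network and the in-memory buffer are assumptions for illustration; for the C++ frontend one would pass a file path to `torch.jit.save` instead. I have only confirmed the make_fx + trace combination on torch==2.0.0.

```python
import io

import torch
from torch.func import vmap, jacrev
from torch.fx.experimental.proxy_tensor import make_fx

torch.manual_seed(0)

# Hypothetical small stand-in for Model, to keep the sketch fast.
tiny = torch.nn.Sequential(
    torch.nn.Linear(2, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)
jac_func = vmap(jacrev(tiny))

dummy_inp = torch.zeros(1, 2)

# Step 1: make_fx flattens the functorch transforms into a plain aten-level graph.
graph_module = make_fx(jac_func)(dummy_inp)

# Step 2: trace the resulting GraphModule into TorchScript.
traced = torch.jit.trace(graph_module, dummy_inp)

# Step 3: serialize and reload (a buffer here; a file path works the same way).
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# Step 4: compare against the eager transform on a fresh input of the same shape.
test_inp = torch.randn(1, 2)
loaded_out = loaded(test_inp)
round_trip_ok = torch.allclose(loaded_out, jac_func(test_inp), atol=1e-6)
print(loaded_out.shape, round_trip_ok)
```

Note that the test input has the same batch size as `dummy_inp`. The generated graph contains literal shapes (such as the `[1, 2]` in the `view` call visible in the error trace further down), so I would not rely on it generalizing to other batch sizes without checking.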
but
torch.jit.script(make_fx(jac_func)(dummy_inp))
would still fail with
RuntimeError:
attribute lookup is not defined on builtin:
File "<eval_with_key>.0", line 6
def forward(self, arg0_1):
_param_constant0 = self._param_constant0
t = torch.ops.aten.t.default(_param_constant0); _param_constant0 = None
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 1); arg0_1 = None
view = torch.ops.aten.view.default(unsqueeze, [1, 2]); unsqueeze = None
The working solution (make_fx + trace), however, fails easily for certain operations in the model (e.g. LayerNorm):
class ModelLN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layer = torch.nn.Linear(2, 512)
        hidden_layers = []
        for i in range(2):
            hidden_layers.append(torch.nn.Linear(512, 512))
        self.ln = torch.nn.LayerNorm(512)
        self.hidden_layer = torch.nn.ModuleList(hidden_layers)
        self.out_layer = torch.nn.Linear(512, 1)

    def forward(self, x):
        x = self.input_layer(x)
        for layer in self.hidden_layer:
            x = torch.tanh(layer(x))
        x = self.ln(x)
        x = self.out_layer(x)
        return x

model_ln = ModelLN()
jac_func = vmap(jacrev(model_ln))
dummy_inp = torch.zeros(1, 2)
torch.jit.trace(make_fx(jac_func)(dummy_inp), dummy_inp)
with
RuntimeError: Found an unsupported argument type in the JIT tracer. File a bug report.
I managed to get a TorchScript with LayerNorm for vmap(jacrev(model_ln)) using an adaptation of functorch.compile.ts_compile (happy to share if it helps), but completely failed to generate a TorchScript for vmap(hessian(model_ln)) including a LayerNorm.
My question is: is there a recommended, more robust way to generate TorchScript for torch.func operations that I am missing?
Appreciate the help!