JIT Script of functorch

Hi,

I am trying to JIT / generate TorchScript for functorch transforms (jacrev, hessian, vmap).

The goal is to save the TorchScript to a file and load / run inference via the C++ frontend (including CUDA).

I have multiple partially working solutions. However, they are not robust to network operations beyond dense layers. I am wondering if there is a straightforward/recommended way to do this?

Example using torch==2.0.0:

import torch
from torch.func import vmap, jacrev, hessian

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.input_layer = torch.nn.Linear(2, 512)

        hidden_layers = []
        for i in range(2):
            hidden_layers.append(torch.nn.Linear(512, 512))

        self.hidden_layer = torch.nn.ModuleList(hidden_layers)
        self.out_layer = torch.nn.Linear(512, 1)

    def forward(self, x):
        x = self.input_layer(x)
        for layer in self.hidden_layer:
            x = torch.tanh(layer(x))
        x = self.out_layer(x)
        return x

model = Model()
jac_func = vmap(jacrev(model))
hess_func = vmap(hessian(model))

The (in my mind) obvious way to try

torch.jit.script(jac_func)

fails with

[...]
  File ".../python3.10/inspect.py", line 797, in getfile
    raise TypeError('module, class, method, function, traceback, frame, or '
TypeError: module, class, method, function, traceback, frame, or code object was expected, got Model

and

dummy_inp = torch.zeros(1, 2)
torch.jit.trace(jac_func, dummy_inp)

fails with

[...]
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
 0.4287  0.0843
 0.6689 -0.2589
[...]
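
The trace error comes from the model parameters: they require grad, and the tracer tries to bake them into the graph as constants. One commonly suggested fix for this specific error, sketched here without guarantees (functional_call is part of torch.func as of 2.0), is to detach the parameters and substitute them via torch.func.functional_call. Tracing may still hit other functorch/tracer limitations further along, which is why I moved on to make_fx below:

from torch.func import functional_call, jacrev, vmap

# Detach the parameters so the tracer no longer encounters constants
# that require grad; functional_call substitutes them into the module
# call. (A sketch: the resulting function may still fail to trace
# through vmap/jacrev for other reasons.)
detached_params = {k: v.detach() for k, v in model.named_parameters()}

def fmodel(x):
    return functional_call(model, detached_params, (x,))

jac_func_detached = vmap(jacrev(fmodel))
torch.jit.trace(jac_func_detached, torch.zeros(1, 2))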

Now, what is working:

from torch.fx.experimental.proxy_tensor import make_fx
dummy_inp = torch.zeros(1, 2)
torch.jit.trace(make_fx(jac_func)(dummy_inp), dummy_inp)

but

torch.jit.script(make_fx(jac_func)(dummy_inp))

would still fail with

RuntimeError: 
attribute lookup is not defined on builtin:
  File "<eval_with_key>.0", line 6
def forward(self, arg0_1):
    _param_constant0 = self._param_constant0
    t = torch.ops.aten.t.default(_param_constant0);  _param_constant0 = None
        ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 1);  arg0_1 = None
    view = torch.ops.aten.view.default(unsqueeze, [1, 2]);  unsqueeze = None
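
Tracing the make_fx graph, at least, yields a saveable artifact: torch.jit.trace returns a ScriptModule here, so the serialization part of the goal works with standard TorchScript save/load (file name arbitrary):

fx_graph = make_fx(jac_func)(dummy_inp)
traced = torch.jit.trace(fx_graph, dummy_inp)
traced.save("jac_func.pt")  # load from C++ via torch::jit::load("jac_func.pt")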

The working solution (make_fx + trace), however, readily fails for certain operations in the model (e.g. LayerNorm):

class ModelLN(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.input_layer = torch.nn.Linear(2, 512)

        hidden_layers = []
        for i in range(2):
            hidden_layers.append(torch.nn.Linear(512, 512))

        self.ln = torch.nn.LayerNorm(512)
        self.hidden_layer = torch.nn.ModuleList(hidden_layers)
        self.out_layer = torch.nn.Linear(512, 1)

    def forward(self, x):
        x = self.input_layer(x)
        for layer in self.hidden_layer:
            x = torch.tanh(layer(x))
        x = self.ln(x)
        x = self.out_layer(x)
        return x

model_ln = ModelLN()
jac_func = vmap(jacrev(model_ln))

dummy_inp = torch.zeros(1, 2)
torch.jit.trace(make_fx(jac_func)(dummy_inp), dummy_inp)

fails with

RuntimeError: Found an unsupported argument type in the JIT tracer. File a bug report.
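
One avenue that might sidestep this, sketched here as an untested assumption (torch._decomp.get_decompositions and the decomposition_table argument of make_fx are internal 2.0 APIs), is to decompose LayerNorm and its backward into primitive aten ops before tracing, so the JIT tracer never sees the composite op:

from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Decompose layer norm (forward and backward) into primitive aten ops
# before make_fx records the graph. These decompositions and op names
# are internal APIs, so treat this as an assumption, not a guarantee.
decomps = get_decompositions([
    torch.ops.aten.native_layer_norm,
    torch.ops.aten.native_layer_norm_backward,
])
fx_graph = make_fx(jac_func, decomposition_table=decomps)(dummy_inp)
torch.jit.trace(fx_graph, dummy_inp)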

I managed to get a TorchScript with LayerNorm for vmap(jacrev(model_ln)) using an adaptation of functorch.compile.ts_compile (happy to share if it helps), but completely failed to generate a TorchScript for vmap(hessian(model_ln)) including a LayerNorm.
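
For reference, the stock version of that pathway is roughly the following sketch (aot_function and ts_compile live in functorch.compile; ts_compile compiles the extracted graphs with torch.jit.script). Note that aot_function returns a Python callable rather than a saveable ScriptModule, which is why an adaptation is needed to get a serializable artifact:

from functorch.compile import aot_function, ts_compile

# aot_function extracts forward/backward fx graphs and hands each to the
# given compiler; the stock ts_compile backend scripts them with
# TorchScript. A sketch of the unadapted pathway, not a drop-in solution.
compiled = aot_function(jac_func, fw_compiler=ts_compile)
out = compiled(dummy_inp)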

My question is: is there a recommended, more robust way to generate TorchScript from torch.func transforms that I am missing?

Appreciate the help!

Anyone? Giving this a bump for visibility.
Thanks