Hi All,

I was just wondering how `torch.jit.script` can be used on functions that take an `nn.Module` as an argument as well as a `Tensor`?

For example, the script below computes the Laplacian of a given function. The `laplacian_jit` function takes two arguments: the function `net` and the input `xs` (the points at which we take the Laplacian). However, running it fails with the following error:

```
RuntimeError:
Unknown type name 'nn.Module':
File "test_jit_func_with_module.py", line 31
@torch.jit.script
def laplacian_jit(net: nn.Module, xs: Tensor):
~~~~~~~~~ <--- HERE
xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
xs_flat = torch.stack(xis, dim=1)
```

It seems that TorchScript doesn't support `nn.Module` as an argument type. Is there a way to annotate the argument so that I can pass an `nn.Module` object into the scripted function?
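One workaround, offered here as a minimal sketch under stated assumptions rather than a definitive answer: instead of passing the module as an argument, store it as a submodule of a wrapper `nn.Module` and script the wrapper. `torch.jit.script` compiles submodules recursively, so `self.net(...)` is resolved at compile time. The `Laplacian` class name is hypothetical. This version differentiates the i-th gradient column with respect to the whole input and keeps only the diagonal entry, which assumes samples in the batch do not interact (e.g. no batch norm in training mode); like the original, it costs one extra backward pass per input dimension.

```python
import torch
import torch.nn as nn
from typing import List, Optional
from torch import Tensor


class Net(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x.pow(2).sum(dim=-1)


class Laplacian(nn.Module):
    """Hypothetical wrapper: stores `net` as a submodule instead of taking it as an argument."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net  # scripted recursively together with this wrapper

    def forward(self, xs: Tensor) -> Tensor:
        # Fresh leaf tensor that requires grad, so we can differentiate w.r.t. it
        xs = xs.detach().requires_grad_()
        ys = self.net(xs)  # one scalar per sample, shape (batch,)
        grad_outputs = torch.jit.annotate(List[Optional[Tensor]], [])
        grad_outputs.append(torch.ones_like(ys))
        # First derivatives dy/dx, shape (batch, d); keep the graph for the second pass
        dy_dx = torch.autograd.grad([ys], [xs], grad_outputs, retain_graph=True, create_graph=True)[0]
        if dy_dx is None:
            raise RuntimeError("gradient is None")
        lap = torch.zeros_like(ys)
        for i in range(xs.size(-1)):
            # Differentiate the i-th gradient column and keep only d^2y/dx_i^2
            d2 = torch.autograd.grad([dy_dx[..., i]], [xs], grad_outputs, retain_graph=True)[0]
            if d2 is None:
                raise RuntimeError("second derivative is None")
            lap = lap + d2[..., i]
        return lap


lap_module = torch.jit.script(Laplacian(Net()))
out = lap_module(torch.randn(4096, 2))  # for f(x) = sum_i x_i^2 in 2-D, every entry is 4
```

The module is scripted once and then called like any function, so the `nn.Module` never has to appear in a function signature.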

The example script is below.

Any help will be greatly appreciated!

Thank you!

```
import torch
import torch.nn as nn
from typing import List, Optional
from torch import Tensor


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, x):
        # f(x) = sum_i x_i^2, computed per sample
        return x.pow(2).sum(dim=-1)


net = Net()


@torch.jit.script
def sumit(inp: List[Optional[torch.Tensor]]):
    # Sum a list of optional tensors, failing if any entry is None
    elt = inp[0]
    if elt is None:
        raise RuntimeError("blah")
    base = elt
    for i in range(1, len(inp)):
        next_elt = inp[i]
        if next_elt is None:
            raise RuntimeError("blah")
        base = base + next_elt
    return base


# Compilation fails here: TorchScript rejects `nn.Module` as an argument annotation
@torch.jit.script
def laplacian_jit(net: nn.Module, xs: Tensor):
    # Split xs into per-dimension columns so each can be differentiated separately
    xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
    xs_flat = torch.stack(xis, dim=1)
    ys = net(xs_flat.view_as(xs))
    ones = torch.ones_like(ys)
    grad_outputs = torch.jit.annotate(List[Optional[Tensor]], [])
    grad_outputs.append(ones)
    # First derivatives, keeping the graph for the second pass
    result = torch.autograd.grad([ys], [xs_flat], grad_outputs, retain_graph=True, create_graph=True)
    dy_dxs = result[0]
    if dy_dxs is None:
        raise RuntimeError("blah")
    generator_as_list = [dy_dxs[..., i] for i in range(len(xis))]
    # Second derivative d^2y/dx_i^2 for each dimension i
    lap_ys_components = [torch.autograd.grad([dy_dxi], [xi], grad_outputs, retain_graph=True, create_graph=False)[0]
                         for xi, dy_dxi in zip(xis, generator_as_list)]
    lap_ys = sumit(lap_ys_components)
    return lap_ys


x = torch.randn(4096, 2)
laplacian_jit(net, x)  # pass both the module and the input
```
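As a slow but simple cross-check for any scripted version, here is a minimal eager-mode sketch (no TorchScript) of the same quantity. For each sample $x \in \mathbb{R}^d$ the Laplacian is the trace of the Hessian, $\Delta f(x) = \sum_{i=1}^{d} \partial^2 f / \partial x_i^2$. The helper name `laplacian_reference` is hypothetical; it uses `torch.autograd.functional.hessian` and loops over the batch in Python, so it is only meant for small inputs. For the `Net` above, $f(x) = \sum_i x_i^2$, so every diagonal term is 2 and every entry should be $2d = 4$ for 2-D inputs.

```python
import torch
import torch.nn as nn
from torch import Tensor
from torch.autograd.functional import hessian


class Net(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x.pow(2).sum(dim=-1)


def laplacian_reference(net: nn.Module, xs: Tensor) -> Tensor:
    # Hypothetical helper: Laplacian of each sample as the trace of its (d, d) Hessian
    values = []
    for x in xs:  # x has shape (d,), so net(x) is a scalar
        h = hessian(net, x)
        values.append(torch.diagonal(h).sum())
    return torch.stack(values)


ref = laplacian_reference(Net(), torch.randn(8, 2))  # for this Net, every entry is 4
```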