Hi All,
I was just wondering how torch.jit.script can be used on functions that take an nn.Module as an argument as well as a Tensor. For example, I have a script below which takes the Laplacian of a given function; the laplacian_jit function takes two arguments: the function, net, and the input x (of which we are taking the Laplacian). However, running it fails with the following error:
RuntimeError:
Unknown type name 'nn.Module':
File "test_jit_func_with_module.py", line 31
@torch.jit.script
def laplacian_jit(net: nn.Module, xs: Tensor):
~~~~~~~~~ <--- HERE
xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
xs_flat = torch.stack(xis, dim=1)
It seems that the JIT doesn't support passing nn.Module as an argument type? Is there a way to define the type so that I can pass an nn.Module object into the scripted function?
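One idea I came across is torch.jit.interface, which (if I understand it right) lets you declare the forward signature as a type and annotate the argument with that instead of nn.Module. A minimal sketch of what I mean (ModuleInterface, Square, and apply_net are just names I made up, and I'm not sure this is the intended use):

import torch
from torch import Tensor

@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    # Declares the signature a conforming module's forward must have.
    def forward(self, x: Tensor) -> Tensor:
        pass

class Square(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x.pow(2).sum(dim=-1)

@torch.jit.script
def apply_net(net: ModuleInterface, x: Tensor) -> Tensor:
    # The interface-typed argument is called through .forward explicitly.
    return net.forward(x)

# The module has to be scripted before it can be passed in.
out = apply_net(torch.jit.script(Square()), torch.randn(4096, 2))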
The full example script is below. Any help will be greatly appreciated, thank you!
import torch
import torch.nn as nn
from typing import List, Optional
from torch import Tensor
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, x):
        return x.pow(2).sum(dim=-1)
net = Net()
@torch.jit.script
def sumit(inp: List[Optional[torch.Tensor]]):
    # Sum a list of Optional tensors, refining away None so TorchScript accepts it.
    elt = inp[0]
    if elt is None:
        raise RuntimeError("blah")
    base = elt
    for i in range(1, len(inp)):
        next_elt = inp[i]
        if next_elt is None:
            raise RuntimeError("blah")
        base = base + next_elt
    return base
@torch.jit.script
def laplacian_jit(net: nn.Module, xs: Tensor):
    # Make each flattened input coordinate a leaf that requires grad.
    xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
    xs_flat = torch.stack(xis, dim=1)
    ys = net(xs_flat.view_as(xs))
    ones = torch.ones_like(ys)
    grad_outputs = torch.jit.annotate(List[Optional[Tensor]], [])
    grad_outputs.append(ones)
    # First derivatives dy/dx, keeping the graph for the second pass.
    result = torch.autograd.grad([ys], [xs_flat], grad_outputs, retain_graph=True, create_graph=True)
    dy_dxs = result[0]
    if dy_dxs is None:
        raise RuntimeError("blah")
    generator_as_list = [dy_dxs[..., i] for i in range(len(xis))]
    # Second derivatives d2y/dxi2; their sum is the Laplacian.
    lap_ys_components = [torch.autograd.grad([dy_dxi], [xi], grad_outputs, retain_graph=True, create_graph=False)[0]
                         for xi, dy_dxi in zip(xis, generator_as_list)]
    lap_ys = sumit(lap_ys_components)
    return lap_ys
x = torch.randn(4096, 2)
laplacian_jit(net, x)
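As an alternative, would making net a submodule of a wrapper and scripting the whole wrapper avoid the nn.Module argument entirely? A rough sketch of what I mean (Wrapper is a hypothetical name; the real forward would hold the laplacian_jit body from above, calling self.net instead of an argument):

class Wrapper(nn.Module):
    # Hypothetical wrapper: the net is captured as a submodule, so it gets
    # scripted recursively and no nn.Module argument type is needed.
    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, xs: Tensor) -> Tensor:
        # Placeholder body; the Laplacian computation would go here.
        return self.net(xs)

laplacian_mod = torch.jit.script(Wrapper(net))
print(laplacian_mod(torch.randn(4096, 2)))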