RuntimeError: Unknown type name 'nn.Module' - How to jit a function whose arguments contain an nn.Module?

Hi All,

I was wondering how torch.jit.script can be used on a function that takes an nn.Module as an argument in addition to a Tensor.

For example, the script below computes the Laplacian of a given function. The laplacian_jit function takes two arguments: the function itself, net, and the input xs at which the Laplacian is evaluated. However, scripting it fails with the following error:

```
Unknown type name 'nn.Module':
  File "", line 31
def laplacian_jit(net: nn.Module, xs: Tensor):
                       ~~~~~~~~~ <--- HERE
  xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
  xs_flat = torch.stack(xis, dim=1)
```

It seems that TorchScript doesn't accept nn.Module as an argument type. Is there a way to annotate the argument so that I can pass an nn.Module object into the scripted function?
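For reference, this is the per-sample Laplacian that laplacian_jit is meant to compute, worked through for the Net in the example script, where f(x) is the sum of the squared coordinates of a d-dimensional input:

```latex
\nabla^2 f(x) = \sum_{i=1}^{d} \frac{\partial^2 f}{\partial x_i^2}(x),
\qquad
f(x) = \sum_{i=1}^{d} x_i^2
\;\Rightarrow\;
\frac{\partial f}{\partial x_i} = 2x_i,
\quad
\frac{\partial^2 f}{\partial x_i^2} = 2,
\quad
\nabla^2 f(x) = 2d .
```

With the 2-D inputs used in the script (d = 2), every entry of the result should therefore be 4.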

The full example script is below.

Any help will be greatly appreciated!

Thank you!

```python
import torch
import torch.nn as nn

from typing import List, Optional
from torch import Tensor

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
  def forward(self, x):
    return x.pow(2).sum(dim=-1)
net = Net()

def sumit(inp: List[Optional[torch.Tensor]]):
  # Sum a list of Optional tensors; the None checks refine the type for TorchScript
  elt = inp[0]
  if elt is None:
      raise RuntimeError("blah")
  base = elt
  for i in range(1, len(inp)):
    next_elt = inp[i]
    if next_elt is None:
        raise RuntimeError("blah")
    base = base + next_elt
  return base

def laplacian_jit(net: nn.Module, xs: Tensor):
  # Split xs into one leaf tensor per coordinate and restack them, so each
  # second derivative can later be taken w.r.t. a single coordinate
  xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
  xs_flat = torch.stack(xis, dim=1)
  ys = net(xs_flat.view_as(xs))

  # ys is not a scalar, so autograd.grad needs explicit grad_outputs;
  # ones_like(ys) also matches the shape of each dy_dxi used below
  ones = torch.ones_like(ys)
  grad_outputs = torch.jit.annotate(List[Optional[Tensor]], [ones])
  result = torch.autograd.grad([ys], [xs_flat], grad_outputs, retain_graph=True, create_graph=True)
  dy_dxs = result[0]
  if dy_dxs is None:
      raise RuntimeError("blah")

  # d^2 y / d x_i^2 for each coordinate i
  generator_as_list = [dy_dxs[..., i] for i in range(len(xis))]
  lap_ys_components = [torch.autograd.grad([dy_dxi], [xi], grad_outputs, retain_graph=True, create_graph=False)[0]
                          for xi, dy_dxi in zip(xis, generator_as_list)]

  lap_ys = sumit(lap_ys_components)

  return lap_ys

x = torch.randn(4096, 2)
laplacian_scripted = torch.jit.script(laplacian_jit)  # fails with the error shown above
```
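One possible workaround, sketched here as an assumption rather than a confirmed answer: TorchScript compiles submodules that are stored as attributes, so instead of annotating a function argument as nn.Module, the network can be held in a small wrapper module and the wrapper scripted. `Laplacian` is a hypothetical name. The loop over coordinates replaces the per-coordinate requires_grad_ trick from laplacian_jit: it takes the gradient of each column of dy/dx with respect to the whole input and keeps column i, which is valid because samples in the batch are independent. The sketch also assumes the caller passes an input that already requires grad.

```python
from typing import List, Optional

import torch
import torch.nn as nn
from torch import Tensor


class Net(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x.pow(2).sum(dim=-1)


class Laplacian(nn.Module):
    # Hypothetical wrapper: `net` is stored as a submodule, so torch.jit.script
    # compiles it together with forward() and no nn.Module argument is needed.
    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, xs: Tensor) -> Tensor:
        # Assumption: xs has shape (batch, d) and already has requires_grad=True.
        ys = self.net(xs)  # shape (batch,)
        ones = torch.jit.annotate(List[Optional[Tensor]], [torch.ones_like(ys)])
        dy_dxs = torch.autograd.grad([ys], [xs], ones, create_graph=True)[0]
        if dy_dxs is None:
            raise RuntimeError("net output does not depend on xs")

        lap = torch.zeros_like(ys)
        for i in range(xs.size(1)):
            dy_dxi = dy_dxs[:, i]  # first derivative w.r.t. coordinate i, per sample
            ones_i = torch.jit.annotate(List[Optional[Tensor]], [torch.ones_like(dy_dxi)])
            # Samples are independent, so column i of this gradient is d^2 y / d x_i^2.
            d2 = torch.autograd.grad([dy_dxi], [xs], ones_i, retain_graph=True)[0]
            if d2 is None:
                raise RuntimeError("first derivative does not depend on xs")
            lap = lap + d2[:, i]
        return lap


laplacian = torch.jit.script(Laplacian(Net()))
x = torch.randn(4096, 2, requires_grad=True)
lap = laplacian(x)  # for this Net, every entry should be 4 (2 per coordinate)
```

Because the network is an attribute rather than an argument, using a different network means constructing a new `Laplacian(...)` and scripting it again.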
