torch.jit.script(module) vs. @torch.jit.script decorator

What is the purpose of the @torch.jit.script decorator? Why does adding @torch.jit.script to my module class result in an error, while calling torch.jit.script on an instance of that module works? For example, this fails:

import torch

@torch.jit.script
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
scripted_cell = torch.jit.script(my_cell)  # script compiles from source; no example inputs are required (unlike torch.jit.trace)
print(scripted_cell)
scripted_cell(x, h)

  File "C:\Users\Administrator\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\jit\__init__.py", line 1262, in script
    raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead

The following code, on the other hand, works:

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
scripted_cell = torch.jit.script(my_cell)  # script compiles from source; no example inputs are required (unlike torch.jit.trace)
print(scripted_cell)
scripted_cell(x, h)

@torch.jit.script can be used as a decorator on functions to script (compile) them, so these two snippets are roughly equivalent:

Decorator

@torch.jit.script
def fn(a: int):
    return a + 1

fn(3)  # fn here is a scripted function

Function Call

def fn(a: int):
    return a + 1

s_fn = torch.jit.script(fn)
s_fn(3)  # s_fn here is a scripted function
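One way to check that the two forms produce the same kind of object is to compare their types. This is a minimal sketch, assuming a recent PyTorch release that exposes torch.jit.ScriptFunction; the second function's name g is illustrative only. Run it from a .py file rather than an interactive prompt, since the TorchScript compiler needs to read each function's source.

```python
import torch

# Decorator form: the name `fn` is rebound to the compiled function.
@torch.jit.script
def fn(a: int):
    return a + 1

# Call form: the original Python function `g` is left untouched,
# and the compiled version is returned as a separate object.
def g(a: int):
    return a + 1

s_g = torch.jit.script(g)

# Both compiled objects are ScriptFunction instances; `g` is still plain Python.
print(isinstance(fn, torch.jit.ScriptFunction), isinstance(s_g, torch.jit.ScriptFunction))
print(fn(3), s_g(3))

# `.code` shows the TorchScript source the compiler generated.
print(fn.code)
```

The only practical difference is whether the undecorated Python function remains available under its own name.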

The decorator can also be applied to classes that extend object (rather than nn.Module); this compiles them as TorchScript classes.
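As an illustration, here is a minimal TorchScript class used from a scripted function. The names Pair, total and pair_total are hypothetical; the sketch assumes attributes are assigned in __init__ from type-annotated arguments so the compiler can infer their types.

```python
import torch

# Decorating a plain (non-Module) class compiles it as a TorchScript class.
@torch.jit.script
class Pair(object):
    def __init__(self, first: torch.Tensor, second: torch.Tensor):
        # Attributes are assigned in __init__ so their types are known to the compiler.
        self.first = first
        self.second = second

    def total(self) -> torch.Tensor:
        return self.first + self.second


# The class can then be constructed and used inside other scripted code.
@torch.jit.script
def pair_total(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    p = Pair(a, b)
    return p.total()


print(pair_total(torch.ones(2), torch.ones(2)))
```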

Because torch.jit.script accepts nn.Module instances but not nn.Module subclasses themselves, @torch.jit.script cannot be used as a decorator on a subclass of nn.Module. Instead, create an instance of the module and pass it to torch.jit.script to script it.
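The sketch below reuses the MyCell module from the question to show both behaviours side by side: passing the class raises the RuntimeError quoted in the question, while passing an instance returns a ScriptModule. The exact error wording may vary between PyTorch versions.

```python
import torch


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


# Passing the class itself is rejected (this is what the decorator does).
try:
    torch.jit.script(MyCell)
    rejected = False
except RuntimeError as err:
    rejected = True
    print("class rejected:", err)

# Passing an instance compiles it into a ScriptModule.
scripted_cell = torch.jit.script(MyCell())
x, h = torch.rand(3, 4), torch.rand(3, 4)
out, new_h = scripted_cell(x, h)
print(isinstance(scripted_cell, torch.jit.ScriptModule), out.shape)
```

Unlike torch.jit.trace, torch.jit.script does not need example inputs here: it compiles forward from its source code.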