[JIT] Scripted attributes inside module

Hello there,

I’ve run into some unexpected behaviour with torch.jit.script(); it can be reproduced with the code below.
As I understood from the docs ( https://pytorch.org/docs/stable/jit.html#id3 ), you can use custom TorchScript classes as long as they are written correctly. But…

import torch
from typing import Tuple

@torch.jit.script
class MyClass(object):
    def __init__(self, weights = (1.0, 1.0, 1.0, 1.0,)):
        # type: (Tuple[float, float, float, float])
        self.weights = weights
    def apply(self):
        # type: () -> Tuple[float, float, float, float]
        return self.weights

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.class_field = torch.jit.export(MyClass(weights = (1.0, 1.0, 1.0, 1.0,)))
    def forward(self, x):
        self.class_field.apply()
        return x + 10

m = torch.jit.script(MyModule())

This produces the following error:

RuntimeError: 
Module 'MyModule' has no attribute 'class_field' (This attribute exists on the Python module, but we failed to convert Python type: 'MyClass' to a TorchScript type.):
at script_test.py:20:8
    def forward(self, x):
        self.class_field.apply()
        ~~~~~~~~~~~~~~~~ <--- HERE
        return x + 10

I think I’ve solved the problem. It turns out that attributes assigned in __init__(self, ...) must be explicitly annotated at class level, like this:

class MyModule(torch.nn.Module):
    class_field: MyClass  # class-level annotation so TorchScript knows the attribute's type
    def __init__(self):
        super().__init__()
        self.class_field = torch.jit.export(MyClass(weights = (1.0, 1.0, 1.0, 1.0,)))

That is, when they hold instances of custom scripted classes.
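For reference, a complete version of the working module might look like the sketch below. It assumes a PyTorch version where Python 3 type annotations are accepted in scripted code; it drops the default argument on MyClass.__init__ and the torch.jit.export wrapper for simplicity, and the use of w[0] in forward is purely illustrative.

```python
import torch
from typing import Tuple


@torch.jit.script
class MyClass(object):
    def __init__(self, weights: Tuple[float, float, float, float]):
        self.weights = weights

    def apply(self) -> Tuple[float, float, float, float]:
        return self.weights


class MyModule(torch.nn.Module):
    # Class-level annotation: tells the compiler this attribute holds a MyClass
    class_field: MyClass

    def __init__(self):
        super().__init__()
        # Plain assignment is enough; no torch.jit.export wrapper is needed
        self.class_field = MyClass((1.0, 1.0, 1.0, 1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.class_field.apply()
        # Use the first weight just to show the attribute is reachable from forward
        return x + w[0] + 10


m = torch.jit.script(MyModule())
print(m(torch.zeros(2)))
```

With the annotation in place, scripting succeeds and forward can call methods on the class-typed attribute.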

P.S. I’m still not sure what the proper way to initialize a custom class field is, i.e.:

self.class_field = torch.jit.export(MyClass(weights = (1.0, 1.0, 1.0, 1.0,)))

or just

self.class_field = MyClass(weights = (1.0, 1.0, 1.0, 1.0,))

The behaviour is the same either way.
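One likely reason the two forms behave the same: in current PyTorch sources, torch.jit.export only tags its argument with a marker attribute and returns the very same object, so wrapping an instance in it changes nothing about the assignment. The class Weights below is a hypothetical stand-in for MyClass; any object whose attributes can be set works.

```python
import torch


class Weights:  # hypothetical stand-in for MyClass
    def __init__(self, values):
        self.values = values


w = Weights((1.0, 1.0, 1.0, 1.0))
# torch.jit.export marks its argument and hands it straight back
wrapped = torch.jit.export(w)
print(wrapped is w)
```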


This is a bug; I’ve filed https://github.com/pytorch/pytorch/issues/29597, since this is something the compiler should be able to handle automatically.

torch.jit.export is meant to be used as a decorator on methods that need to be compiled but are not called from forward (or from anything forward calls). So in your code you can simply drop the call to torch.jit.export. See these docs for details.
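A minimal sketch of the intended use of torch.jit.export as a method decorator; the method name scale and its factor parameter are hypothetical, chosen only for illustration. Because scale is never reached from forward, the compiler would normally skip it, and the decorator makes it compiled and callable on the scripted module.

```python
import torch


class MyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 10

    # Not called from forward, so it is compiled only because of @torch.jit.export
    @torch.jit.export
    def scale(self, x: torch.Tensor, factor: float) -> torch.Tensor:
        return x * factor


m = torch.jit.script(MyModule())
print(m(torch.zeros(2)))
print(m.scale(torch.ones(2), 3.0))
```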
