Hello there,
I’ve run into some unexpected behaviour of `torch.jit.script()`. The issue can be reproduced with the code below.
As I learned from the docs ( https://pytorch.org/docs/stable/jit.html#id3 ), one can use a custom TorchScript class if it’s properly written. But…
```python
import torch
from typing import Tuple

@torch.jit.script
class MyClass(object):
    def __init__(self, weights = (1.0, 1.0, 1.0, 1.0,)):
        # type: (Tuple[float, float, float, float])
        self.weights = weights

    def apply(self):
        # type: () -> Tuple[float, float, float, float]
        return self.weights

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.class_field = torch.jit.export(MyClass(weights = (1.0, 1.0, 1.0, 1.0,)))

    def forward(self, x):
        self.class_field.apply()
        return x + 10

m = torch.jit.script(MyModule())
```
It produces the following error:
```
RuntimeError:
Module 'MyModule' has no attribute 'class_field' (This attribute exists on the Python module, but we failed to convert Python type: 'MyClass' to a TorchScript type.):
at script_test.py:20:8
    def forward(self, x):
        self.class_field.apply()
        ~~~~~~~~~~~~~~~~ <--- HERE
        return x + 10
```
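For comparison, here is a variant that drops `torch.jit.export` (the docs describe it as a decorator for `nn.Module` methods, not a wrapper for attribute values) and assigns the instance directly. This is only a sketch, assuming a newer PyTorch release in which module attributes whose type is a TorchScript class can be inferred; I haven't confirmed which version first supports this, and on the version that produced the error above it may fail the same way. The type comments are swapped for Python 3 annotations and the default argument is dropped to keep it minimal.

```python
import torch
from typing import Tuple


@torch.jit.script
class MyClass(object):
    def __init__(self, weights: Tuple[float, float, float, float]):
        # The attribute type is inferred from this assignment.
        self.weights = weights

    def apply(self) -> Tuple[float, float, float, float]:
        return self.weights


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain assignment of the scripted-class instance, no torch.jit.export.
        self.class_field = MyClass((1.0, 1.0, 1.0, 1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Use the returned weights so the call is not dead code.
        w = self.class_field.apply()
        return x * w[0] + 10


m = torch.jit.script(MyModule())
print(m(torch.zeros(2)))  # tensor([10., 10.])
```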