Hi, I’m trying to debug an unrelated issue, so I tried to construct a very basic test case.
However, I’m getting a different error when I run this basic test case, which makes me question whether I understand the basics of JIT compilation.
I defined the following two dummy classes:
```python
from typing import Optional, Tuple

from torch import Tensor


class Dummy1:
    def __init__(self):
        self.__name__ = "dummy"

    def forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        return (x, x)


class Dummy2(Dummy1):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        return super().forward(x)
```
I get the following error when I call torch.jit.script on an instance of Dummy2:
```
>>> import torch
>>> torch.__version__
'1.8.1'
>>> import multihead_attention_window
>>> dummy = multihead_attention_window.Dummy2()
>>> torch.jit.script(dummy)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/psridhar/.pyenv/versions/py3/lib/python3.8/site-packages/torch/jit/_script.py", line 986, in script
    ast = get_jit_def(obj, obj.__name__)
  File "/Users/psridhar/.pyenv/versions/py3/lib/python3.8/site-packages/torch/jit/frontend.py", line 240, in get_jit_def
    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
  File "/Users/psridhar/.pyenv/versions/py3/lib/python3.8/site-packages/torch/_utils_internal.py", line 54, in get_source_lines_and_file
    filename = inspect.getsourcefile(obj)
  File "/Users/psridhar/.pyenv/versions/3.8.1/lib/python3.8/inspect.py", line 696, in getsourcefile
    filename = getfile(object)
  File "/Users/psridhar/.pyenv/versions/3.8.1/lib/python3.8/inspect.py", line 676, in getfile
    raise TypeError('module, class, method, function, traceback, frame, or '
TypeError: module, class, method, function, traceback, frame, or code object was expected, got Dummy2
```
I googled around and found suggestions that this might be a torch version issue, but I’m on torch 1.8.1. What am I missing here?
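The last frames of the traceback are in Python's standard inspect module: torch passes the object to inspect.getsourcefile, which only accepts a module, class, method, function, traceback, frame, or code object. The following stdlib-only sketch reproduces that final TypeError without torch. It assumes the same class layout as the test case, with the type annotations dropped so it runs without torch installed; source_file_error is a hypothetical helper added only for illustration.

```python
import inspect


class Dummy1:
    def __init__(self):
        # Mirrors the test case; this attribute does not make the
        # instance acceptable to inspect.
        self.__name__ = "dummy"

    def forward(self, x):
        return (x, x)


class Dummy2(Dummy1):
    def forward(self, x):
        return super().forward(x)


def source_file_error(obj):
    """Hypothetical helper: return inspect's TypeError message for obj, or None."""
    try:
        inspect.getsourcefile(obj)
    except TypeError as exc:
        return str(exc)
    return None


dummy = Dummy2()
# A plain object instance is none of the accepted kinds, so inspect raises
# the same TypeError seen at the bottom of the traceback.
print(source_file_error(dummy))
# A bound method is accepted: inspect unwraps it to its function's code object.
print(isinstance(inspect.getfile(dummy.forward), str))
```

In other words, the failure happens before any TorchScript compilation: the object handed to torch.jit.script is an instance of a plain Python class, and inspect cannot locate source for it. Whether TorchScript would then accept these classes is a separate question this sketch does not address.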