Hi,
I am trying to use Python dataclasses with TorchScript (`torch.jit.script`). It works for a simple dataclass, for example:
```python
import torch
from dataclasses import dataclass


@torch.jit.script
@dataclass
class SimpleView:
    id: int
    name: str


def test_simple_view():
    view = SimpleView(
        id=1,
        name='foo',
    )
    assert view.name == 'foo'
    print(view)
```
However, when I try a more complex case, where one dataclass is nested inside another, it does not work: TorchScript fails to compile the outer dataclass. Here is a code sample:
```python
import torch
from dataclasses import dataclass


@torch.jit.script
@dataclass
class ValueStruct:
    label: str
    score: float


@torch.jit.script
@dataclass
class EmbeddedView:
    id: int
    name: str
    value: ValueStruct


def test_embedded_view():
    view = EmbeddedView(
        id=2,
        name='bar',
        value=ValueStruct(
            label='baz',
            score=9.9,
        )
    )
    assert view.value.label == 'baz'
    print(view)
```
This is the error I get when running the test case above:
```
============================= test session starts ==============================
collecting ...
main.py:None (main.py)
main.py:48: in <module>
    class EmbeddedView:
../../python/py39_1/lib/python3.9/site-packages/torch/jit/_script.py:1323: in script
    _compile_and_register_class(obj, _rcb, qualified_name)
../../python/py39_1/lib/python3.9/site-packages/torch/jit/_recursive.py:47: in _compile_and_register_class
    script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
E   RuntimeError:
E   Unknown type name 'main.ValueStruct':
E     File "__torch_jit_dataclass/EmbeddedView/__init__", line 0
E   def __init__(self, id: int, name: str, value: main.ValueStruct) -> None:
E                                                 ~~~~~~~~~~~~~~~~ <--- HERE
E       self.id = id
E       self.name = name
ERROR: not found: /Users/dan.xu/dev/projects/ts_dataclass/main.py::test_embedded_view
collected 0 items / 1 error
(no name '/Users/dan.xu/dev/projects/ts_dataclass/main.py::test_embedded_view' in any of [<Module main.py>])
=============================== 1 error in 1.90s ===============================
Process finished with exit code 4
```
I found that dataclass support in TorchScript was added last year in these PRs (#72901, #76771), and I am using PyTorch 1.13.1, which should include it. Is there a way to make nested dataclasses work, or is this a limitation (or bug) in TorchScript's dataclass support? Any advice would be appreciated, thanks!
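For comparison, one possible workaround is sketched below. It assumes that TorchScript's regular class support (an explicit, type-annotated `__init__` instead of one synthesized by `@dataclass`) can reference another scripted class as an attribute type. The class and field names are reused from the example above; `get_label` is a hypothetical helper added only to show the nested field being read from inside a scripted function. This is a sketch, not a confirmed fix for the dataclass path itself. Like any `torch.jit.script` class, it must be defined in a source file (not typed into a REPL) so TorchScript can read its source.

```python
import torch


# Plain TorchScript classes with explicit, type-annotated __init__ methods
# (no @dataclass); the fields mirror the dataclasses in the post.
@torch.jit.script
class ValueStruct:
    def __init__(self, label: str, score: float):
        self.label = label
        self.score = score


@torch.jit.script
class EmbeddedView:
    def __init__(self, id: int, name: str, value: ValueStruct):
        self.id = id
        self.name = name
        # The attribute type is inferred from the annotated ValueStruct parameter.
        self.value = value


# Hypothetical helper: reads the nested field inside TorchScript code.
@torch.jit.script
def get_label(view: EmbeddedView) -> str:
    return view.value.label


def test_embedded_view():
    view = EmbeddedView(id=2, name='bar', value=ValueStruct(label='baz', score=9.9))
    assert view.value.label == 'baz'
    assert get_label(view) == 'baz'
```

The trade-off is losing what `@dataclass` generates, such as `__repr__` and `__eq__`, so `print(view)` would show the default object representation instead of the field values.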