I ran into an interesting issue when trying to export a simple `nn.Module`. Here is the code:
```python
import torch
import torch.nn as nn


class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.x = 1
        self.y = 2

    def forward(self, x):
        return 2 * x

    # @torch.jit.export  # <-- UNCOMMENT HERE
    def get_x(self):
        return self.x


mod = MyMod()
traced = torch.jit.trace(mod, [torch.tensor([1.0])])
print(traced.get_x())
print(traced.x)
print(traced.y)
```
In its current form, the code fails on all 3 print statements. This is AS EXPECTED.
Now, when I uncomment the `@torch.jit.export` annotation, all 3 statements pass.
- I understand `get_x()`, because I explicitly exported it.
- I kind of get `x`, because it's a dependency of `get_x()`.
- I totally don't get `y`, because it's just an independent, non-tensor attribute.
Tested on torch==2.0.1.
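To make the comparison concrete, here is a minimal sketch that traces the same module with and without the decorator and reports which names can be read back. `WithoutExport`, `WithExport`, and `surviving_attrs` are hypothetical names introduced for illustration. Based on the behavior above on torch==2.0.1, the first print should show an empty list and the second all three names; other versions may differ. One plausible explanation, from reading `torch/jit/_trace.py` (not verified for every release): a module that has `@torch.jit.export` methods is built through TorchScript's recursive scripting path, which keeps every attribute whose type it can infer and also compiles the module's properties, while a module without exports keeps roughly only its parameters, buffers, and submodules.

```python
import torch
import torch.nn as nn


class WithoutExport(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.x = 1
        self.y = 2

    def forward(self, x):
        return 2 * x

    def get_x(self):
        return self.x


class WithExport(WithoutExport):
    # Same module, but get_x is explicitly marked for TorchScript compilation.
    @torch.jit.export
    def get_x(self):
        return self.x


def surviving_attrs(traced, names):
    """Return the subset of `names` that can be read from the traced module."""
    found = []
    for name in names:
        try:
            getattr(traced, name)
            found.append(name)
        except Exception:  # the exact exception type may vary between versions
            pass
    return found


example = (torch.tensor([1.0]),)
plain = torch.jit.trace(WithoutExport(), example)
exported = torch.jit.trace(WithExport(), example)

print(surviving_attrs(plain, ["get_x", "x", "y"]))
print(surviving_attrs(exported, ["get_x", "x", "y"]))
```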
Most of the time I wouldn’t really care about some extra properties exported alongside my model.
The problem occurs when I try to use external libraries like HuggingFace.
For example, when I add `ViTForImageClassification` as a submodule of `MyMod`, all of its properties are recursively traced, and this leads to annotation errors like the one below:
```
Unknown type name 'nn.Module':
  File "blah/site-packages/transformers/modeling_utils.py", line 1197
    @property
    def base_model(self) -> nn.Module:
```
This is just an innocent annotated property that gets caught in the crossfire.
Any idea whether this is expected behavior? Should I report a bug? I already found an ugly workaround, but this feels like something that should have a simpler solution.
Workaround for anyone interested
Just create a new `MyExports` module and move the exports there. Since the behavior is recursive, it will be limited to the `MyExports` module and leave the ViT model alone. You can then access them like
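A minimal sketch of the workaround, under stated assumptions: `Backbone` is a hypothetical stand-in for `ViTForImageClassification` (its annotated `base_model` property mimics the one in `modeling_utils.py`), and the exported value is assumed to live on `MyExports` itself. Because `MyMod` has no exported methods, only `MyExports` should go through TorchScript compilation, so the backbone's property is never compiled.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical stand-in for a third-party model such as ViTForImageClassification."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    # Mimics the transformers pattern: an annotated property TorchScript cannot compile.
    @property
    def base_model(self) -> nn.Module:
        return self

    def forward(self, x):
        return self.linear(x)


class MyExports(nn.Module):
    """Holds only the values and methods that should be exported."""

    def __init__(self) -> None:
        super().__init__()
        self.x = 1

    @torch.jit.export
    def get_x(self):
        return self.x


class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = Backbone()
        # MyMod itself has no @torch.jit.export methods, so it is traced normally;
        # only this submodule takes the scripting path.
        self.exports = MyExports()

    def forward(self, x):
        return self.backbone(x)


mod = MyMod()
traced = torch.jit.trace(mod, (torch.tensor([1.0]),))
print(traced.exports.get_x())
```

In this sketch the exported values are reached through the submodule, e.g. `traced.exports.get_x()`, rather than directly on `traced`.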