I ran into an interesting issue when trying to export a simple `nn.Module`. Here is the code:
```python
import torch
import torch.nn as nn

class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.x = 1
        self.y = 2

    def forward(self, x):
        return 2 * x

    # @torch.jit.export  # <-- UNCOMMENT HERE
    def get_x(self):
        return self.x

mod = MyMod()
traced = torch.jit.trace(mod, [torch.tensor([1.0])])
print(traced.get_x())
print(traced.x)
print(traced.y)
```
In its current form, the code fails on all three print statements, as expected.
When I uncomment the `@torch.jit.export` decorator, all three statements pass.
- I understand `get_x()`, because I explicitly exported it.
- I kind of get `x`, because it's a dependency of `get_x()`.
- I totally don't get `y`, because it's just an independent, non-tensor attribute.
Tested on torch==2.0.1.
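To make the comparison easy to rerun, here is a minimal sketch that traces the module above with and without the decorator and reports which names are reachable on the result. `MyModExported` and `inspect_traced` are hypothetical helpers added for illustration, and the sketch assumes torch 2.x. As far as I can tell, the exported variant comes back as a `RecursiveScriptModule` (the same wrapper `torch.jit.script` produces), while the plain one is a `TopLevelTracedModule`, which suggests that tracing switches to a scripting-style path as soon as the top-level module has an export.

```python
import torch
import torch.nn as nn


class MyMod(nn.Module):
    """The module from the post, with get_x left undecorated."""

    def __init__(self) -> None:
        super().__init__()
        self.x = 1
        self.y = 2

    def forward(self, x):
        return 2 * x

    def get_x(self):
        return self.x


class MyModExported(MyMod):
    """Hypothetical variant: same module, but get_x carries @torch.jit.export."""

    @torch.jit.export
    def get_x(self):
        return self.x


def inspect_traced(mod: nn.Module):
    """Trace `mod`, then report the wrapper type and which names are reachable."""
    traced = torch.jit.trace(mod, [torch.tensor([1.0])])
    visible = {name: hasattr(traced, name) for name in ("get_x", "x", "y")}
    return type(traced).__name__, visible


print(inspect_traced(MyMod()))
print(inspect_traced(MyModExported()))
```

If the behavior matches what I describe above, the first line reports all three names as unreachable and the second reports all three as reachable.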
Most of the time I wouldn't really care about a few extra attributes exported alongside my model.
The problem shows up when I use external libraries such as HuggingFace `transformers`.
For example, when I add `ViTForImageClassification` as a submodule of `MyMod`, all of its properties are recursively compiled as well, which leads to annotation errors like the one below:
```
Unknown type name 'nn.Module':
  File "blah/site-packages/transformers/modeling_utils.py", line 1197
    @property
    def base_model(self) -> nn.Module:
```
This is just an innocent annotated property that gets caught in the crossfire.
Is this expected behavior? Should I report a bug? I already found an ugly workaround, but this feels like something that should have a simpler solution.
Workaround for anyone interested
Just create a new `MyExports` module and move the exported methods there. Since the recursive behavior starts from the module that owns the exports, it stays limited to `MyExports` and leaves the ViT model alone. You can then access them like `mod.exports.get_x()`.
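A minimal sketch of that workaround, assuming torch 2.x. `Backbone` is a hypothetical stand-in for `ViTForImageClassification` (so the example runs without `transformers`), with a property annotated the same way as `base_model` in the traceback; `MyExports` and `Backbone` are illustrative names, not library classes. Because only the `exports` submodule carries `@torch.jit.export`, the top-level module presumably goes through the plain tracing path and the backbone's properties are never compiled.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical stand-in for ViTForImageClassification (not the real class)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    @property
    def base_model(self) -> nn.Module:
        # Annotated like the transformers property from the error message.
        return self

    def forward(self, x):
        return self.linear(x)


class MyExports(nn.Module):
    """Holds every @torch.jit.export method, so only this submodule is scripted."""

    def __init__(self) -> None:
        super().__init__()
        self.x = 1

    @torch.jit.export
    def get_x(self) -> int:
        return self.x


class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = Backbone()
        self.exports = MyExports()  # exports live here, not on MyMod itself

    def forward(self, x):
        return 2 * self.backbone(x)


mod = MyMod()
traced = torch.jit.trace(mod, [torch.tensor([1.0])])
print(traced.exports.get_x())
```

The trade-off is one extra level of indirection (`traced.exports.get_x()` instead of `traced.get_x()`), in exchange for keeping third-party submodules out of the scripting pass.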