The directory structure is like this:
.
├── main.py
└── models
    └── model.py
The content of main.py is:
import torch
import importlib
def get_model(pth='./models/model.py'):
    """Load ``Block`` from the file at *pth* and return a subclass of it.

    Args:
        pth: Path of the Python file that defines ``Block``. Defaults to
            the location used in this example; in the real program it is
            supplied as an input argument.

    Returns:
        The ``Child`` class (not an instance); call it to build the model.

    Raises:
        ImportError: If no import spec or loader can be created for *pth*.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly here.
    import importlib.util

    spec = importlib.util.spec_from_file_location('model', pth)
    if spec is None or spec.loader is None:
        raise ImportError(f'cannot load a module from {pth!r}')
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    class Child(mod.Block):
        """Subclass of the dynamically loaded ``Block`` with a compiled ``forward``."""

        def __init__(self):
            super().__init__()

        # NOTE: with this decorator in place and the instance also passed to
        # ``torch.compile``, the script fails on PyTorch 2.0.1 with
        # "TypeError: forward() missing 1 required positional argument: 'x'".
        # Commenting the decorator out makes the script run.
        @torch.compile
        def forward(self, x):
            outputs = super().forward(x)
            return outputs

    return Child
# Build an instance of the dynamically created ``Child`` class.
model = get_model()()
# The whole module is compiled here, in addition to the ``@torch.compile``
# decorator already applied to ``Child.forward``.
model = torch.compile(model)
optim = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(10):
    # Random input/target pair with 32 channels, matching the convolutions.
    inten = torch.randn(2, 32, 128, 128)
    outen = torch.randn(2, 32, 128, 128)
    optim.zero_grad()
    out = model(inten)  # the reported TypeError is raised on this call
    loss = (out - outen).mean()
    loss.backward()
    optim.step()
    print(loss.item())
And the content of models/model.py is:
import torch
class Block(torch.nn.Module):
    """Two stacked 3x3 convolutions, each with 32 input and 32 output channels."""

    def __init__(self):
        super().__init__()
        # Positional args: in_channels=32, out_channels=32, kernel_size=3,
        # stride=1, padding=1.
        self.conv1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)

    def forward(self, x):
        """Apply ``conv1`` followed by ``conv2`` to *x* and return the result."""
        return self.conv2(self.conv1(x))
The error message is:
Traceback (most recent call last):
File "/home/projects/useless/main.py", line 55, in <module>
out = model(inten)
File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
TypeError: forward() missing 1 required positional argument: 'x'
I am working on an Ubuntu system with PyTorch 2.0.1 installed from conda; the CUDA version is 11.8.
If I comment out the `@torch.compile` line, the code works as expected. Could you tell me how I can make this work?