torch.compile does not work with a module loaded via importlib

The directory structure is as follows:

.
├── main.py
└── models
    └── model.py

The content of main.py is:

import torch
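import importlib.util  # import the submodule explicitly; plain "import importlib" does not guarantee importlib.util is loaded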


def get_model():

    pth = './models/model.py'  # in the real script, this path is passed in as an input argument
    spec = importlib.util.spec_from_file_location('model', pth)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)

    class Child(mod.Block):

        def __init__(self):
            super(Child, self).__init__()

        @torch.compile
        def forward(self, x):
            outputs = super().forward(x)
            return outputs

    return Child

model = get_model()()
model = torch.compile(model)

optim = torch.optim.SGD(model.parameters(), lr=1e-3)

for i in range(10):
    inten = torch.randn(2, 32, 128, 128)
    outen = torch.randn(2, 32, 128, 128)

    optim.zero_grad()
    out = model(inten)
    loss = (out - outen).mean()
    loss.backward()
    optim.step()
    print(loss.item())

And the content of models/model.py is:

import torch

class Block(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(32, 32, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(32, 32, 3, 1, 1)

    def forward(self, x):
        return self.conv2(self.conv1(x))
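The dynamic-loading step in main.py (spec_from_file_location, module_from_spec, exec_module, then subclassing mod.Block) can be checked in isolation. Below is a minimal, torch-free sketch under stated assumptions: MODEL_SOURCE is a hypothetical stand-in for models/model.py with a plain Python Block class, and load_module is a hypothetical helper name. The file is written to a temporary directory, loaded the same way main.py loads it, and subclassed. It imports importlib.util explicitly, since import importlib alone does not guarantee that the util submodule is available.

```python
import importlib.util
import pathlib
import tempfile

# Hypothetical stand-in for models/model.py: a plain class, no torch needed.
MODEL_SOURCE = '''
class Block:
    def forward(self, x):
        return x + 1
'''


def load_module(name, path):
    """Load a module from an explicit file path, as get_model() does."""
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "model.py"
    path.write_text(MODEL_SOURCE)
    mod = load_module("model", str(path))  # module object survives the temp dir


class Child(mod.Block):
    def forward(self, x):
        # Delegate to the dynamically loaded base class, like Child in main.py.
        return super().forward(x) * 10


result = Child().forward(1)  # (1 + 1) * 10
print(result)
```

If a sketch like this runs cleanly, that suggests the importlib pattern by itself is not what breaks, and the difference in main.py lies in how torch.compile wraps Child.forward.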

The error message is:

Traceback (most recent call last):
  File "/home/projects/useless/main.py", line 55, in <module>
    out = model(inten)
  File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/.conda/envs/py39/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
TypeError: forward() missing 1 required positional argument: 'x'
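One plausible reading of this traceback, offered as an assumption rather than a confirmed diagnosis: Child.forward ends up compiled twice, once by the @torch.compile decorator and once through torch.compile(model). The outer wrapper calls self.dynamo_ctx(self._orig_mod.forward), passing a bound method. If the dynamo context then unwraps that callable by following an attribute stored on the inner compiled wrapper, the attribute lookup passes through the bound method to the plain function and the binding to self is lost; model(inten) then binds inten to self and x is reported missing. The pure-Python sketch below reproduces that failure mode; decorate, innermost and _orig_callable are hypothetical stand-ins, not PyTorch APIs.

```python
import functools


def decorate(fn):
    # Hypothetical stand-in for a compile-style decorator: it wraps `fn`
    # and remembers the original callable on the wrapper.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapper._orig_callable = fn  # hypothetical attribute name
    return wrapper


def innermost(fn):
    # Hypothetical unwrapping helper: follow `_orig_callable` links to the
    # innermost function. On a bound method, attribute lookup falls through
    # to the underlying function, so the binding to `self` is silently lost.
    while hasattr(fn, "_orig_callable"):
        fn = fn._orig_callable
    return fn


class Block:
    @decorate
    def forward(self, x):
        return x * 2


block = Block()
bound = block.forward   # bound method around `wrapper`
result = bound(3)       # self is bound, so this works

raw = innermost(bound)  # plain, unbound `forward` function
error_message = None
try:
    raw(3)              # 3 is taken as `self`, so `x` is missing
except TypeError as exc:
    error_message = str(exc)

print(result)
print(error_message)
```

If this reading is right, applying torch.compile in only one place (either the decorator on Child.forward or torch.compile(model), not both) should avoid the error, and importlib would not be involved.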

I am working on Ubuntu with PyTorch 2.0.1 installed from conda; the CUDA version is 11.8.

If I comment out the @torch.compile decorator on Child.forward, the code works as expected. Could you tell me how I can make this work?