Graph tracing fails when using TorchDynamo

Problem description: I'm using the stable release of torch 2.2.1 for quantization-aware training (QAT), and graph capture with TorchDynamo fails as described in the title.
During a normal forward pass, the shape of tensor x1 is [1, 16, 210, 348], but during tracing it is [1, 16, 8, 348]. A mini repo is available to follow what I'm describing. I have no idea what causes this. Any help would be appreciated.
Code snippet:

def forward(self, inputs):
    x = self.aggregator(inputs)
    # Split along dim 1 (the channel dimension for an NCHW input)
    x1, x2 = x[:, :self.dim_extract], x[:, self.dim_extract:]
    x1 = self.extractor_fraction_1(x1)
    x1 = self.extractor_fraction_2(x1)
    x = torch.cat([x1, x2], dim=1)  # concat along channel dimension
    return x
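
One observation that may help narrow this down: the FakeTensor in the error below is 3-D, size=(16, 8, 348), while the conv reports its input as [1, 16, 8, 348]. The following torch-free sketch does only the shape arithmetic for the slice in the snippet above. The helper names slice_dim1 and as_conv2d_input and the value DIM_EXTRACT = 8 are hypothetical, made up for illustration (8 is inferred from the conv weight shape [8, 8, 5, 5]); none of this is code from the model. It suggests the tensor reaching this module during tracing might be missing its batch dimension, though I have not confirmed that.

```python
def slice_dim1(shape, stop):
    """Shape after `x[:, :stop]`: dim 0 is kept whole, dim 1 is cut to `stop`."""
    shape = tuple(shape)
    return shape[:1] + (min(stop, shape[1]),) + shape[2:]


def as_conv2d_input(shape):
    """conv2d treats a 3-D input as unbatched (C, H, W), i.e. (1, C, H, W)."""
    shape = tuple(shape)
    return (1,) + shape if len(shape) == 3 else shape


DIM_EXTRACT = 8  # assumption: inferred from the conv weight [8, 8, 5, 5]

# 4-D input (N, C, H, W): the slice cuts channels, as intended.
batched = as_conv2d_input(slice_dim1((1, 16, 210, 348), DIM_EXTRACT))

# 3-D input (C, H, W): the same slice cuts the height instead.
unbatched = as_conv2d_input(slice_dim1((16, 210, 348), DIM_EXTRACT))

print(batched)    # (1, 8, 210, 348)
print(unbatched)  # (1, 16, 8, 348)
```

The second result matches the input shape in the error message. If that reading is right, the example input handed to the graph capture call would be the first place to check.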

Error detail:

File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 765, in conv
    conv_backend = torch._C._select_conv_backend(**kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___mpfd_blocks_0_out_1_partial_conv_extractor_fraction_1(*(FakeTensor(..., size=(16, 8, 348), grad_fn=<SliceBackward0>),), **{}):
Given groups=1, weight of size [8, 8, 5, 5], expected input[1, 16, 8, 348] to have 8 channels, but got 16 channels instead

from user code:
   File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 247, in forward
    x_forward = mpfd_block(x_forward)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 189, in forward
    out_1 = self.out_1(inputs)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 113, in forward
    x = self.partial_conv(inputs)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 84, in forward
    x1 = self.extractor_fraction_1(x1)

[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] Function                           Runtimes (s)
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] -------------------------------  --------------
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner               0