Model forward code snippet:

def forward(self, inputs):
    x = self.aggregator(inputs)
    x1, x2 = x[:, :self.dim_extract], x[:, self.dim_extract:]
    x1 = self.extractor_fraction_1(x1)
    x1 = self.extractor_fraction_2(x1)
    x = torch.cat([x1, x2], dim=1)  # concat along the channel dimension
    return x
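For context, here is a self-contained sketch of the channel-split pattern in the snippet above. The channel counts (16 total, dim_extract = 8) are inferred from the error message, and the module definitions are simplified stand-ins for the real ones in fan.py:

```python
import torch
import torch.nn as nn

class PartialConvSketch(nn.Module):
    """Minimal stand-in for the channel-split block: the first dim_extract
    channels go through two convs, the rest pass through untouched."""
    def __init__(self, channels=16, dim_extract=8):
        super().__init__()
        self.dim_extract = dim_extract
        # Both extractors must take exactly dim_extract input channels,
        # matching the weight of size [8, 8, 5, 5] in the error message.
        self.extractor_fraction_1 = nn.Conv2d(dim_extract, dim_extract, 5, padding=2)
        self.extractor_fraction_2 = nn.Conv2d(dim_extract, dim_extract, 5, padding=2)

    def forward(self, inputs):
        x1, x2 = inputs[:, :self.dim_extract], inputs[:, self.dim_extract:]
        x1 = self.extractor_fraction_1(x1)
        x1 = self.extractor_fraction_2(x1)
        return torch.cat([x1, x2], dim=1)  # concat along the channel dimension

m = PartialConvSketch()
out = m(torch.randn(1, 16, 210, 348))  # the eager-mode shape from the report
print(out.shape)  # torch.Size([1, 16, 210, 348])
```

In eager mode this runs fine, since the slice hands the conv a [1, 8, 210, 348] tensor; the error only appears once Dynamo traces with fake tensors.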
Error detail:
File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 765, in conv
conv_backend = torch._C._select_conv_backend(**kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___mpfd_blocks_0_out_1_partial_conv_extractor_fraction_1(*(FakeTensor(..., size=(16, 8, 348), grad_fn=<SliceBackward0>),), **{}):
Given groups=1, weight of size [8, 8, 5, 5], expected input[1, 16, 8, 348] to have 8 channels, but got 16 channels instead
from user code:
File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 247, in forward
x_forward = mpfd_block(x_forward)
File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 189, in forward
out_1 = self.out_1(inputs)
File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 113, in forward
x = self.partial_conv(inputs)
File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 84, in forward
x1 = self.extractor_fraction_1(x1)
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] Function Runtimes (s)
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] ------------------------------- --------------
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner 0
I'm using torch 2.2 stable to do QAT and hit the error above during TorchDynamo graph capture. In a normal eager forward pass, x1 has shape [1, 16, 210, 348], but during tracing it becomes [1, 16, 8, 348]. I have no idea what causes the discrepancy. Any help would be appreciated!
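To see where the eager and traced shapes diverge, one thing I tried was logging the input shape every submodule actually receives. This is only a debugging sketch (the helper name and the tiny model are mine, not from fan.py), and note that hooks with side effects can cause graph breaks under torch.compile:

```python
import torch
import torch.nn as nn

def record_input_shapes(model, log):
    """Register forward pre-hooks that append (module class, input shapes)
    to `log` each time a submodule is called. Returns the hook handles so
    they can be removed afterwards with handle.remove()."""
    def hook(mod, args):
        log.append((mod.__class__.__name__,
                    [tuple(a.shape) for a in args if torch.is_tensor(a)]))
    return [m.register_forward_pre_hook(hook) for m in model.modules()]

# Tiny stand-in model just to demonstrate the helper.
model = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1))
eager_log = []
handles = record_input_shapes(model, eager_log)
model(torch.randn(1, 16, 8, 8))
for h in handles:
    h.remove()
print(eager_log)
# [('Sequential', [(1, 16, 8, 8)]), ('Conv2d', [(1, 16, 8, 8)])]
```

Running the same input through the eager model and the compiled one with this logging enabled should pinpoint the first submodule whose input shape differs between the two runs.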