Graph tracing fails on a tensor slicing operation

Model forward code snippet:

    def forward(self, inputs):
        x = self.aggregator(inputs)
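        # Split along the channel dimension: the first dim_extract channels go
        # through the extractors, the remaining channels pass through unchanged.
        x1, x2 = x[:, :self.dim_extract], x[:, self.dim_extract:]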
        x1 = self.extractor_fraction_1(x1)
        x1 = self.extractor_fraction_2(x1)
        x = torch.cat([x1, x2], dim=1) # concat along channel dimension
        return x

Error details:

  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 765, in conv
    conv_backend = torch._C._select_conv_backend(**kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___mpfd_blocks_0_out_1_partial_conv_extractor_fraction_1(*(FakeTensor(..., size=(16, 8, 348), grad_fn=<SliceBackward0>),), **{}):
Given groups=1, weight of size [8, 8, 5, 5], expected input[1, 16, 8, 348] to have 8 channels, but got 16 channels instead

from user code:
   File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 247, in forward
    x_forward = mpfd_block(x_forward)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 189, in forward
    out_1 = self.out_1(inputs)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 113, in forward
    x = self.partial_conv(inputs)
  File "/mnt/afs/user/jarvis/miniconda3/envs/ntire2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/afs/user/jarvis/projects/efficient-sr/NTIRE2024_ESR_Challenge/NTIRE2024_ESR/src/model/fan.py", line 84, in forward
    x1 = self.extractor_fraction_1(x1)

[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] Function                           Runtimes (s)
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] -------------------------------  --------------
[2024-03-09 18:14:22,879] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner               0

I’m using torch 2.2 (stable) for QAT and hit the problem above when capturing the graph with TorchDynamo. In the normal forward loop, x1 has shape [1, 16, 210, 348], while during tracing its shape is [1, 16, 8, 348]. I have no idea what causes this. Any help would be appreciated~
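One possible reading of the log, not confirmed anywhere in the thread: the FakeTensor passed to extractor_fraction_1 has size (16, 8, 348), i.e. only three dimensions, yet the error talks about input[1, 16, 8, 348]. nn.Conv2d accepts unbatched (C, H, W) input and adds a batch dimension of 1 itself, so a 3-D tensor whose first dimension is 16 is read as having 16 channels. The sketch below reproduces the same kind of message in eager mode; the 8-channel, 5×5 weight comes from the log, the 210 × 348 spatial size from the post, and zero padding is an assumption.

```python
import torch
import torch.nn as nn

# Same weight shape as in the log: [8, 8, 5, 5] (in_channels=8, kernel_size=5).
# Padding is left at 0 here, which is an assumption.
conv = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5)

# A 4-D (N, C, H, W) input with 8 channels is accepted.
ok = conv(torch.randn(1, 8, 210, 348))
ok_shape = tuple(ok.shape)
print(ok_shape)  # (1, 8, 206, 344): no padding, so H and W shrink by 4

# A 3-D tensor is treated as an unbatched (C, H, W) input, so the sliced
# shape (16, 8, 348) is read as 16 channels of size 8 x 348.
err_msg = ""
try:
    conv(torch.randn(16, 8, 348))
except RuntimeError as exc:
    err_msg = str(exc)
print(err_msg)
```

If this reading is right, the tensor reaching the slice during capture would be 3-D, so x[:, :self.dim_extract] slices the height instead of the channels. That would point at the example inputs given to the capture call (for instance a tensor without the batch dimension) rather than at the slicing itself, but that is a guess.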

Can you give a minimal repro? It’s hard to follow from just what you’ve included.

Sure! Here is the link to where the error occurs. The forward logic that fails during graph capture takes the output of a depthwise convolution as its input, and the shape-mismatch error appears when I call the capture_pre_autograd_graph function.
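A minimal, self-contained repro skeleton could look like the following. Apart from the forward body taken from the original snippet, everything here is hypothetical: the class name PartialConv, the depthwise aggregator, the channel counts (16 in, dim_extract=8, matching the [8, 8, 5, 5] weight in the log), the padding, and the 64 × 64 example input. It runs the module eagerly first, then attempts the capture with torch._export.capture_pre_autograd_graph (the API used in the torch 2.2 PT2E flow) and prints any exception instead of crashing.

```python
import torch
import torch.nn as nn


class PartialConv(nn.Module):
    """Hypothetical stand-in for the failing block in fan.py."""

    def __init__(self, channels: int = 16, dim_extract: int = 8):
        super().__init__()
        self.dim_extract = dim_extract
        # Assumed depthwise aggregator (the post says the block receives
        # the output of a depthwise convolution); padding keeps H and W.
        self.aggregator = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)
        # Weight shape [8, 8, 5, 5] as in the log; padding=2 is assumed.
        self.extractor_fraction_1 = nn.Conv2d(dim_extract, dim_extract, 5, padding=2)
        self.extractor_fraction_2 = nn.Conv2d(dim_extract, dim_extract, 5, padding=2)

    def forward(self, inputs):
        x = self.aggregator(inputs)
        # Split channels: the first dim_extract go through the extractors.
        x1, x2 = x[:, :self.dim_extract], x[:, self.dim_extract:]
        x1 = self.extractor_fraction_1(x1)
        x1 = self.extractor_fraction_2(x1)
        return torch.cat([x1, x2], dim=1)  # concat along channel dimension


model = PartialConv()
example_inputs = (torch.randn(1, 16, 64, 64),)  # 4-D (N, C, H, W) input

# 1) Eager run: should work, as it does in the real training loop.
out = model(*example_inputs)
print(out.shape)  # torch.Size([1, 16, 64, 64])

# 2) Graph capture, as in the torch 2.2 QAT flow. Guarded because the
#    import path differs across torch versions.
try:
    from torch._export import capture_pre_autograd_graph

    gm = capture_pre_autograd_graph(model, example_inputs)
    print(gm.graph)
except Exception as exc:
    print(f"capture failed: {type(exc).__name__}: {exc}")
```

Replacing example_inputs with the exact tensors used in the real capture call (and checking their .dim()) is one quick way to tell whether the mismatch comes from the model or from the inputs.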

Can you run the model without the graph capture?

Yes, it works fine: the entire training forward loop completes and the best checkpoint is saved.

This is an export problem; can you post it in the torch.compile category of the PyTorch Forums?

Got it, thanks for the pointer.