mat1 and mat2 shapes cannot be multiplied (8x64 and 3x0)

I’ve created an Attention U-Net with embedded Squeeze-and-Excitation blocks, and here is my model:

# U-Net and its parts
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv_op(x)

class SqueezeExcitation(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.flat = nn.Flatten()
        self.ch = nn.Sequential(
            nn.Linear(in_features=channels, out_features=channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )
    def forward(self, x):
        out = self.avg_pool(x)
        out = self.flat(out)
        out = self.ch(out)
        out = out.view(out.size(0), out.size(1), 1, 1)
        return x * out

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop = nn.Dropout2d(0.2)
        self.se = SqueezeExcitation(in_channels)
        
    def forward(self, x):
        down = self.conv(x)
        down = self.se(down)
        down = self.drop(down)
        p = self.pool(down)
        return down, p

class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)

class Attention(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.normal = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=2)
        self.one = nn.Conv2d(out_channels, 1, kernel_size=1)

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        
        self.resample = nn.Upsample(scale_factor=2)

    def forward(self, X, skip_X):
        x = self.normal(X)
        skip = self.down(skip_X)
        x = x + skip

        x = self.relu(x)
        x = self.one(x)
        x = self.sigmoid(x)
        
        x = self.resample(x)
        return x * skip_X

class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(UNet, self).__init__()
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        self.attention_gate_1 = Attention(1024, 512)
        self.up_convolution_1 = UpSample(1024, 512)
        self.attention_gate_2 = Attention(512, 256)
        self.up_convolution_2 = UpSample(512, 256)
        self.attention_gate_3 = Attention(256, 128)
        self.up_convolution_3 = UpSample(256, 128)
        self.attention_gate_4 = Attention(128, 64)
        self.up_convolution_4 = UpSample(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        down_1, p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)

        b = self.bottle_neck(p4)

        a1 = self.attention_gate_1(b, down_4)
        up_1 = self.up_convolution_1(b, a1)
        a2 = self.attention_gate_2(up_1, down_3)
        up_2 = self.up_convolution_2(up_1, a2)
        a3 = self.attention_gate_3(up_2, down_2)
        up_3 = self.up_convolution_3(up_2, a3)
        a4 = self.attention_gate_4(up_3, down_1)
        up_4 = self.up_convolution_4(up_3, a4)

        out = self.out(up_4)
        return torch.sigmoid(out)

My hyperparameters are BATCH_SIZE = 8 and INPUT_SHAPE_UNET = (8, 3, 256, 256).

Here is the snippet I used to check whether the model works:

dummy_input = torch.randn((8,3,256,256))
model = UNet(3,1)
model.forward(dummy_input)

And here is the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[143], line 4
      2 dummy_input = torch.randn((8,3,256,256))
      3 model = UNet(3,1)
----> 4 model.forward(dummy_input)

Cell In[142], line 106, in UNet.forward(self, x)
    105 def forward(self, x):
--> 106     down_1, p1 = self.down_convolution_1(x)
    107     down_2, p2 = self.down_convolution_2(p1)
    108     down_3, p3 = self.down_convolution_3(p2)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[142], line 44, in DownSample.forward(self, x)
     42 def forward(self, x):
     43     down = self.conv(x)
---> 44     down = self.se(down)
     45     down = self.drop(down)
     46     p = self.pool(down)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[142], line 30, in SqueezeExcitation.forward(self, x)
     28 out = self.avg_pool(x)
     29 out = self.flat(out)
---> 30 out = self.ch(out)
     31 out = out.view(out.size(0), out.size(1), 1, 1)
     32 return x * out

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x64 and 3x0)

Please help me solve this issue ASAP, @ptrblck and others.

You probably need to change self.se = SqueezeExcitation(in_channels) to self.se = SqueezeExcitation(out_channels) in the DownSample class. The SE block is applied to the output of self.conv, a DoubleConv that produces out_channels feature maps, so it has to be built for out_channels. That also matches the numbers in the error: the first block is DownSample(3, 64), so the SE layers are created for 3 channels (and 3 // 16 = 0 makes the first Linear a 3-to-0 layer), while the flattened pooled tensor actually has 64 channels, hence the (8x64 and 3x0) mismatch.
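
For reference, a minimal sketch of the corrected DownSample (identical to the original except for the channel count passed to the SE block):

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.drop = nn.Dropout2d(0.2)
        # DoubleConv outputs out_channels feature maps,
        # so the SE block must be sized for out_channels
        self.se = SqueezeExcitation(out_channels)

    def forward(self, x):
        down = self.conv(x)
        down = self.se(down)
        down = self.drop(down)
        p = self.pool(down)
        return down, p

With that change, the dummy forward pass above should run and return an output of shape (8, 1, 256, 256).
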

Thank you, @mhm2020, for your help.
Actually, after posting this problem I realized the same thing myself, so thank you for taking the time to reply.