U-Net implementation: how do I fix this? I tried changing the padding, stride, and kernel size, but it didn't work.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU.
    padding=1, so height and width are unchanged."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNetAudio(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNetAudio, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder: channels double at each level
        self.down1 = DoubleConv(in_channels, 64)
        self.down2 = DoubleConv(64, 128)
        self.down3 = DoubleConv(128, 256)
        self.down4 = DoubleConv(256, 512)

        # Decoder: ConvTranspose2d with kernel 2, stride 2 doubles height and width
        self.up1 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.up3 = nn.ConvTranspose2d(128, 64, 2, 2)

        self.out = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        # Encoder: each MaxPool2d(2) halves height and width (rounding down)
        x1 = self.down1(x)
        x2 = self.down2(nn.MaxPool2d(2)(x1))
        x3 = self.down3(nn.MaxPool2d(2)(x2))
        x4 = self.down4(nn.MaxPool2d(2)(x3))

        # Decoder
        x = self.up1(x4)
        print(x.shape)
        print(x2.shape)
        print(x3.shape)
        # Fails here: x and x3 differ in width.
        # (Even once widths match, this cat yields 256 + 256 = 512 channels,
        # while up2 was built for 256 input channels; likewise for up3 and out.)
        x = self.up2(torch.cat([x, x3], dim=1))
        x = self.up3(torch.cat([x, x2], dim=1))
        x = self.out(torch.cat([x, x1], dim=1))

        return x
```

```
torch.Size([32, 256, 32, 26])
torch.Size([32, 256, 32, 27])
```

```
RuntimeError                              Traceback (most recent call last)
Cell In[28], line 1
----> 1 training(train_loader, test_loader, num_epochs, model, loss_function, optimiser, scheduler, "UNetAudio")

Cell In[16], line 11, in training(train_loader, val_loader, epochs, model, criterion, optimizer, scheduler, model_name)
      9 batch_Y=Variable(batch_Y.to(device))
     10 optimizer.zero_grad()
---> 11 outputs = model(batch_X)
     12 loss = criterion(outputs.squeeze(), batch_Y)
     13 loss.backward()

File D:\anaconda\envs\Pytorch\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File D:\anaconda\envs\Pytorch\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525 or _global_backward_pre_hooks or _global_backward_hooks
   1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
   1529 try:
   1530 result = None

Cell In[26], line 45, in UNetAudio.forward(self, x)
     43 print(x2.shape)
     44 print(x3.shape)
---> 45 x = self.up2(torch.cat([x, x3], dim=1))
     46 x = self.up3(torch.cat([x, x2], dim=1))
     47 x = self.out(torch.cat([x, x1], dim=1))

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 26 but got size 27 for tensor number 1 in the list.
```

Hi Ayush!

The short story is that the classic U-Net architecture only accepts inputs
of certain specific shapes (although these shapes can vary, and – given
adequate memory – can be arbitrarily large). Your input has an invalid
shape.

I’m not going to sort through your specific model to determine its valid
input shapes, but what is going on is that when you “descend” through
the U, the convolutions may nibble away at the edges of your image (an
unpadded 3x3 convolution, as in the original U-Net, reduces each spatial
extent by two, one pixel on each side; your DoubleConv layers use
padding = 1 and so preserve the size) and then your downsampling
(MaxPool2d(2)) shrinks your image by a factor of two.
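To make the “nibbling” concrete, here is a minimal pure-Python sketch (so it runs without PyTorch) of the output-size formula from the `nn.Conv2d` documentation, assuming dilation 1. `conv2d_out` is a hypothetical helper name; it evaluates the classic unpadded case and the `Conv2d(in, out, 3, 1, 1)` used in the question's `DoubleConv`.

```python
def conv2d_out(n, kernel_size=3, stride=1, padding=0):
    """Size of one spatial dimension after nn.Conv2d, assuming dilation=1.

    Mirrors the documented formula floor((n + 2*padding - kernel_size) / stride) + 1.
    """
    return (n + 2 * padding - kernel_size) // stride + 1


# Classic U-Net: unpadded 3x3 convolution, one pixel lost on each side.
print(conv2d_out(64))             # 62
# DoubleConv in the question: Conv2d(in, out, 3, 1, 1), i.e. padding=1.
print(conv2d_out(64, padding=1))  # 64
```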

If the image that is input to your MaxPool2d(2) has a dimension that
is odd (not a multiple of two), no error occurs, but what would have been a
fractional output dimension is rounded down.
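A pure-Python sketch of that rounding, following the `nn.MaxPool2d` output-size formula with its defaults (`padding=0`, `ceil_mode=False`); `maxpool2d_out` is a hypothetical helper. Note that 54 and 55 both pool down to 27, so the original size cannot be recovered on the way back up.

```python
def maxpool2d_out(n, kernel_size=2, stride=None):
    """Size of one spatial dimension after nn.MaxPool2d.

    Assumes padding=0, dilation=1, and the default ceil_mode=False,
    so the result is rounded down.
    """
    stride = kernel_size if stride is None else stride  # PyTorch's default stride
    return (n - kernel_size) // stride + 1


for n in (54, 55, 27):
    print(n, "->", maxpool2d_out(n))
# 54 -> 27
# 55 -> 27
# 27 -> 13
```

`MaxPool2d` also accepts `ceil_mode=True`, which rounds up instead (27 would become 14, which upsamples to 28); that only changes which sizes mismatch rather than removing the problem.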

When you go back up the U, your upsampling (ConvTranspose2d)
expands your image by a factor of two. You then implement the “skip
connection” by concatenating the matching image from when you
descended the U. If one of the downsamplings between the two ends
of your skip connection rounded down, the shapes of your “matching”
images won’t be the same and torch.cat() will fail with the error you
see.
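Plugging the width of x3 from the printed shapes into the documented output-size formulas (again as hypothetical pure-Python helpers, assuming dilation 1) reproduces the numbers in the error message. Since `down4` uses padded convolutions, x4 has the pooled width, and `ConvTranspose2d(512, 256, 2, 2)` doubles it.

```python
def maxpool2d_out(n):
    """nn.MaxPool2d(2): kernel 2, stride 2, rounded down."""
    return (n - 2) // 2 + 1


def conv_transpose2d_out(n, kernel_size=2, stride=2, padding=0, output_padding=0):
    """Size after nn.ConvTranspose2d, assuming dilation=1 (documented formula)."""
    return (n - 1) * stride - 2 * padding + kernel_size + output_padding


skip_width = 27                            # width of x3 in the printed shapes
pooled = maxpool2d_out(skip_width)         # 13, since 13.5 is rounded down
upsampled = conv_transpose2d_out(pooled)   # 26, the width of up1(x4)
print(skip_width, pooled, upsampled)       # 27 13 26
# torch.cat([x, x3], dim=1) needs the widths to agree (26 vs 27), hence the error.
```

The height, by contrast, is 32, which pools to 16 and upsamples back to 32, so only the width fails.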

You have to track how the shape of your image changes as it moves
through the layers of your U-Net, look for such mismatches, and use
that analysis to determine what input shapes are valid. Then only input
images with valid shapes to your U-Net. (If you have an input image with
an invalid shape, my general recommendation would be to “reflection-pad”
that image up to the next larger valid shape.)
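For the model in the question specifically (size-preserving convolutions, three `MaxPool2d(2)` layers), a sketch of that analysis suggests the valid heights and widths are exactly the multiples of 2**3 = 8. The helpers below are hypothetical and pure Python; they model only spatial sizes, and also compute how much padding reaches the next valid size, returned in the order `torch.nn.functional.pad` uses for the last two dimensions. The 128 x 110 input is just an illustrative example.

```python
def pool_out(n):
    """nn.MaxPool2d(2): halves a spatial size, rounding down."""
    return (n - 2) // 2 + 1


def up_out(n):
    """nn.ConvTranspose2d(in, out, 2, 2): doubles a spatial size."""
    return 2 * n


def is_valid(n, depth=3):
    """True if every skip connection lines up for spatial size n.

    Assumes size-preserving (padding=1) 3x3 convolutions and `depth`
    MaxPool2d(2) layers, as in the question's UNetAudio.
    """
    sizes = [n]
    for _ in range(depth):
        sizes.append(pool_out(sizes[-1]))
    # Each upsampled tensor must match the encoder tensor one level up.
    return all(up_out(sizes[i + 1]) == sizes[i] for i in range(depth))


def next_valid(n, depth=3):
    """Smallest valid size >= n."""
    while not is_valid(n, depth):
        n += 1
    return n


def reflect_pad_amounts(height, width, depth=3):
    """Padding tuple (left, right, top, bottom), the order that
    torch.nn.functional.pad expects for the last two dimensions."""
    extra_h = next_valid(height, depth) - height
    extra_w = next_valid(width, depth) - width
    return (extra_w // 2, extra_w - extra_w // 2,
            extra_h // 2, extra_h - extra_h // 2)


print([n for n in range(100, 130) if is_valid(n)])  # [104, 112, 120, 128]
print(reflect_pad_amounts(128, 110))                # (1, 1, 0, 0)
# In the model you might then do (not run here; H, W are the input's height and width):
#   x = torch.nn.functional.pad(x, reflect_pad_amounts(H, W), mode="reflect")
```

Keep in mind that `mode="reflect"` requires each pad amount to be smaller than the corresponding input dimension, and that you would crop the network's output back, e.g. `out[..., top:top + H, left:left + W]`, if you need the original size.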

Good luck.

K. Frank