RuntimeError: Calculated padded input size per channel: (2 x 2). Kernel size: (3 x 3). Kernel size can't be greater than actual input size

Hi,

I am trying to train a UNet model using a modified implementation of this code example from PyTorch Lightning.

I’m running into an issue (I believe it may originate in the contracting blocks) where my kernel ends up larger than the feature map it is applied to. By the time the error is raised, the activation has shrunk to torch.Size([8, 64, 2, 2]), which is obviously too small to convolve with a 3 x 3 kernel.
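For reference, the failure reproduces in isolation: a 3 x 3 convolution simply cannot slide over a 2 x 2 feature map (the batch and channel sizes below just mirror my shapes):

import torch
import torch.nn as nn

# Minimal reproduction: an unpadded 3x3 conv applied to a 2x2 feature map
# raises the same RuntimeError as in the title.
conv = nn.Conv2d(64, 64, kernel_size=3)
x = torch.randn(8, 64, 2, 2)
conv(x)  # RuntimeError: Kernel size can't be greater than actual input size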

Here is the code for the UNet() object:

class UNet(pl.LightningModule):
    def __init__(self,
                 in_channels,
                 output_channels,
                 hidden_channels=64,
                 depth=3):
        super(UNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv_final = nn.Conv2d(hidden_channels,
                                    output_channels,
                                    kernel_size=1)
        self.depth = depth

        self.contracting_layers = []
        self.expanding_layers = []

        for i in range(0, depth):
            self.contracting_layers += [
                ContractingBlock(hidden_channels * 2**i)
            ]

        for i in range(1, depth + 1):
            self.expanding_layers += [ExpandingBlock(hidden_channels * 2**i)]

        self.contracting_layers = nn.ModuleList(self.contracting_layers)
        self.expanding_layers = nn.ModuleList(self.expanding_layers)

    def forward(self, x):
        depth = self.depth
        contractive_x = []

        x = self.conv1(x)
        contractive_x.append(x)

        for i in range(depth):
            x = self.contracting_layers[i](x)
            contractive_x.append(x)

        for i in range(depth - 1, -1, -1):
            x = self.expanding_layers[i](x, contractive_x[i])
        x = self.conv_final(x)

        return x

    def training_step(self, batch, batch_idx):
        x, y = batch['image'], batch['mask']

        y_pred = self.forward(x)
        loss = criterion(y_pred, y)

        self.log('loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=lr)

It relies on the following DoubleConv, ContractingBlock, UpsampleConv, and ExpandingBlock modules:

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3)),
            nn.ReLU(inplace=True),
            PrintLayer(),
            nn.Conv2d(out_channels, out_channels, (3, 3)),
            nn.ReLU(inplace=True),
            PrintLayer()
        )

    def forward(self, x):
        return self.net(x)

class ContractingBlock(nn.Module):
    def __init__(self, in_channels):
        super(ContractingBlock, self).__init__()
        # conv 3x3 (no padding), ReLU, conv 3x3, ReLU, then max_pool (3x3, stride 2, padding 1)
        self.double_conv = DoubleConv(in_channels, in_channels * 2)
        self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.double_conv(x)
        x = self.pooling(x)

        return x

class UpsampleConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(UpsampleConv, self).__init__()

        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            PrintLayer(),
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size),
            PrintLayer())

    def forward(self, x):
        return self.net(x)

class ExpandingBlock(nn.Module):
    def __init__(self, in_channels):
        super(ExpandingBlock, self).__init__()
        self.upsample = UpsampleConv(in_channels, in_channels // 2)
        self.double_conv = DoubleConv(in_channels, in_channels // 2)

    def forward(self, x, skip_conn):
        x = self.upsample(x)

        # crop skip_conn and add to upsampled x
        cropped_skip_conn = crop(skip_conn, x.shape)
        x = torch.cat([cropped_skip_conn, x], dim=1)

        x = self.double_conv(x)
        return x

Thanks to @PA_Nik, I threw in a custom PrintLayer so that I can examine the activation shapes at each stage.
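(PrintLayer itself isn’t shown above; it is just a pass-through module that prints the shape of whatever it receives. Roughly, a minimal sketch of the version I’m using:)

class PrintLayer(nn.Module):
    # Identity layer that only prints the shape of its input, for debugging.
    def forward(self, x):
        print(x.shape)
        return x

The shapes it prints during one forward pass are: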

torch.Size([8, 128, 62, 62])
torch.Size([8, 128, 60, 60])
torch.Size([8, 256, 28, 28])
torch.Size([8, 256, 26, 26])
torch.Size([8, 512, 11, 11])
torch.Size([8, 512, 9, 9])
torch.Size([8, 512, 10, 10])
torch.Size([8, 256, 8, 8])
torch.Size([8, 256, 6, 6])
torch.Size([8, 256, 4, 4])
torch.Size([8, 256, 8, 8])
torch.Size([8, 128, 6, 6])
torch.Size([8, 128, 4, 4])
torch.Size([8, 128, 2, 2])
torch.Size([8, 128, 4, 4])
torch.Size([8, 64, 2, 2])
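For what it’s worth, the shrinkage can be reproduced on paper. Working backwards from the first printed shape (62 x 62 after an unpadded 3 x 3 conv), the input to the blocks must be 64 x 64 after conv1. Here is a rough sketch of the size arithmetic for one spatial dimension (my own bookkeeping, not part of the model):

def double_conv(s):   # two unpadded 3x3 convs, each removes 2 pixels
    return s - 4

def pool(s):          # MaxPool2d(kernel_size=3, stride=2, padding=1)
    return (s + 2 - 3) // 2 + 1

def up_conv(s):       # Upsample(scale_factor=2) followed by an unpadded 3x3 conv
    return 2 * s - 2

s = 64                       # spatial size after the 1x1 conv1
for _ in range(3):           # contracting path: 64 -> 30 -> 13 -> 5
    s = pool(double_conv(s))
for _ in range(3):           # expanding path
    s = up_conv(s)
    print(s)                 # prints 8, 6, 2 -- matching the shapes above
    s = double_conv(s)       # goes non-positive on the last pass

So the last DoubleConv in the expanding path receives a 2 x 2 activation, which is exactly where the 3 x 3 kernel no longer fits.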

Could somebody please guide me through adding the required padding (or adjusting the contraction) so that the spatial dimensions of my network don’t shrink this far?

Is there an easy way to check which part of the network is contracting too far?

A good approach is to check the activation shapes after each layer, which you’ve already done.
Now check which layers should keep or increase the spatial size of the activation; for example, you could start by dropping some pooling layers and see how many need to be dropped before the shapes stay large enough.
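One way to get a labeled version of that shape trace, rather than scattering PrintLayers by hand, is to register forward hooks on every Conv2d/MaxPool2d/Upsample module; the last line printed before the exception tells you exactly which module made the activation too small. This is just a sketch, assuming the UNet class from the post and an example 3-channel, 64 x 64 input:

import torch
import torch.nn as nn

model = UNet(in_channels=3, output_channels=1, depth=3)  # example arguments

def make_hook(name):
    # Print the module name together with the output shape it produces.
    def hook(module, inputs, output):
        print(f"{name:45s} -> {tuple(output.shape)}")
    return hook

handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, (nn.Conv2d, nn.MaxPool2d, nn.Upsample))
]

try:
    with torch.no_grad():
        model(torch.randn(1, 3, 64, 64))
except RuntimeError as e:
    print("Crashed right after the last module listed above:", e)
finally:
    for h in handles:
        h.remove()

Once you can see which stage drops below the kernel size, common fixes are to drop a pooling stage (i.e. reduce depth), feed larger inputs, or give the 3 x 3 convolutions padding=1 so that only the pooling and upsampling layers change the resolution.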