Replacing up-sampling with transposed convolutions

I have a UNet++ and I’m trying to replace the up-sampling operations with transposed convolution operations. Is there a more straightforward way than defining every new transposed convolution layer separately inside the UNetPP class?

import torch
import torch.nn as nn

class blockUNetPP(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out

class UNetPP(nn.Module):
    def __init__(self, channelExponent=6, dropout=0.,input_channels=3):
        super().__init__()

        channels = int(2 ** channelExponent + 0.5)
        self.up = nn.Upsample(scale_factor=2)

        self.conv0_0 = blockUNetPP(input_channels, channels, channels)
        
        self.bbconv1_0 = nn.Conv2d(channels, channels, 4, 2, 1)
        self.conv1_0 = blockUNetPP(channels, channels*2, channels*2)
        
        self.bbconv2_0 = nn.Conv2d(channels*2, channels*2, 4, 2, 1)
        self.conv2_0 = blockUNetPP(channels*2, channels*4, channels*4)
        
        self.bbconv3_0 = nn.Conv2d(channels*4, channels*4, 4, 2, 1)
        self.conv3_0 = blockUNetPP(channels*4, channels*8, channels*8)
        
        self.bbconv4_0 = nn.Conv2d(channels*8, channels*8, 4, 2, 1)
        self.conv4_0 = blockUNetPP(channels*8, channels*16, channels*16)

        self.conv0_1 = blockUNetPP(channels+channels*2, channels, channels)
        self.conv1_1 = blockUNetPP(channels*2+channels*4, channels*2, channels*2)
        self.conv2_1 = blockUNetPP(channels*4+channels*8, channels*4, channels*4)
        self.conv3_1 = blockUNetPP(channels*8+channels*16, channels*8, channels*8)

        self.conv0_2 = blockUNetPP(channels*2+channels*2, channels, channels)
        self.conv1_2 = blockUNetPP(channels*2*2+channels*4, channels*2, channels*2)
        self.conv2_2 = blockUNetPP(channels*4*2+channels*8, channels*4, channels*4)

        self.conv0_3 = blockUNetPP(channels*3+channels*2, channels, channels)
        self.conv1_3 = blockUNetPP(channels*2*3+channels*4, channels*2, channels*2)

        self.conv0_4 = blockUNetPP(channels*4+channels*2, channels, channels)

        
        self.final = nn.Conv2d(channels, 3, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.bbconv1_0(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.bbconv2_0(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.bbconv3_0(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.bbconv4_0(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))


        output = self.final(x0_4)
        return output
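
For reference, the model as written (with nn.Upsample) runs end to end; here is a quick sanity check I use (my own snippet, assuming an input whose spatial size is divisible by 16 so the four stride-2 downsamplings line up):

model = UNetPP(channelExponent=6, input_channels=3)
x = torch.randn(1, 3, 64, 64)   # dummy batch: one 3-channel 64x64 input
y = model(x)
print(y.shape)                  # torch.Size([1, 3, 64, 64]) – same spatial size, 3 output channels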

I think you would unfortunately have to define separate transposed convolution layers, since the current code just reuses self.up for every upsampling step.
Reusing a single module is fine for nn.Upsample, which has no trainable parameters, but you most likely don’t want to share one nn.ConvTranspose2d across stages, and you couldn’t anyway, because the tensors being upsampled have different channel counts (channels*2, channels*4, channels*8, channels*16).
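
A minimal sketch of what that could look like (the names up1 … up4 are my own, not from your code): one nn.ConvTranspose2d per channel width, with in_channels == out_channels, so all the concatenation sizes expected by the existing blockUNetPP calls stay unchanged. kernel_size=2, stride=2 gives an exact 2x upsampling; kernel_size=4, stride=2, padding=1 would mirror your downsampling convs and also doubles the size.

        # in __init__, replacing self.up = nn.Upsample(scale_factor=2)
        self.up1 = nn.ConvTranspose2d(channels*2,  channels*2,  kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(channels*4,  channels*4,  kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(channels*8,  channels*8,  kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(channels*16, channels*16, kernel_size=2, stride=2)

        # in forward(), pick the layer matching the channel count of the tensor being upsampled, e.g.
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up1(x1_0)], 1))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up2(x2_0)], 1))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up4(x4_0)], 1))

Note that this still shares one transposed conv between all nodes at the same resolution level (self.up1 would be applied to x1_0, x1_1, x1_2 and x1_3); if you want independent weights for every edge of the UNet++ graph, you would indeed need one layer per use site.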

Alternatively, you could define a new submodule that accepts two inputs and applies the transposed convolution internally, but you would still have to change the corresponding definitions and calls in forward.
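
Something along these lines, as a rough sketch (the name upBlockUNetPP and its signature are my suggestion, not existing code): the submodule owns its own ConvTranspose2d, upsamples the coarser feature map, concatenates it with the skip features, and then runs the usual double-conv block.

class upBlockUNetPP(nn.Module):
    def __init__(self, up_channels, skip_channels, middle_channels, out_channels):
        super().__init__()
        # learnable 2x upsampling instead of nn.Upsample
        self.up = nn.ConvTranspose2d(up_channels, up_channels, kernel_size=2, stride=2)
        # reuse the existing double-conv block on the concatenated features
        self.block = blockUNetPP(skip_channels + up_channels, middle_channels, out_channels)

    def forward(self, skips, x):
        # skips: list of same-resolution feature maps, x: coarser map to upsample
        return self.block(torch.cat(skips + [self.up(x)], 1))

Used e.g. as self.node0_1 = upBlockUNetPP(channels*2, channels, channels, channels) in __init__ and x0_1 = self.node0_1([x0_0], x1_0) in forward(), replacing the current conv0_1 definition and call.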

Thanks mate, guess I’ll be breaking out the convolutional abacus again…