How to change UNET backbone to mobilenetv3_large_100?

Below is my code for a UNET model. How can I change its backbone to the CNN model mobilenetv3_large_100? Also, the images I wish to train it on are grayscale.

import torch
import torch.nn as nn
import torch.nn.functional as F


class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # inputs are NCHW; pad x1 to match x2's spatial size
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

class UNET(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNET, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return torch.sigmoid(x)

Hi Nitay!

Probably the most practical approach would be to rewrite your U-Net
to use MobileNet’s “depthwise separable convolutions” and to follow
MobileNetV3’s structure in how much each block downsamples.
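
For concreteness, here is a minimal sketch of a depthwise separable
convolution block (my own naming; MobileNetV3’s real blocks add an
expansion layer, squeeze-and-excitation, and hard-swish on top of this
core idea):

import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super(DepthwiseSeparableConv, self).__init__()
        self.block = nn.Sequential(
            # depthwise: one 3x3 filter per input channel
            nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1,
                      groups=in_ch, bias=False),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True),
            # pointwise: 1x1 convolution mixes the channels
            nn.Conv2d(in_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)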

The issue is that you have to hold on to references to the outputs of the
downsampling blocks so that you can cat() them together with the
upsampled tensors (U-Net’s “cross connections”) as you work up through
the upsampling blocks.
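
The bookkeeping itself is straightforward; here is a sketch of an
encoder forward() that stashes each stage’s output, reusing the
DepthwiseSeparableConv block above (the channel counts loosely echo
mobilenetv3_large_100’s feature pyramid, but treat them as
placeholders):

class MobileNetLikeEncoder(nn.Module):
    def __init__(self):
        super(MobileNetLikeEncoder, self).__init__()
        # placeholder stages; each stride-2 block halves the resolution
        self.stage1 = DepthwiseSeparableConv(1, 16, stride=2)
        self.stage2 = DepthwiseSeparableConv(16, 24, stride=2)
        self.stage3 = DepthwiseSeparableConv(24, 40, stride=2)
        self.stage4 = DepthwiseSeparableConv(40, 112, stride=2)

    def forward(self, x):
        skips = []
        for stage in [self.stage1, self.stage2, self.stage3, self.stage4]:
            x = stage(x)
            skips.append(x)   # keep a reference for the decoder's cat()
        return x, skips[:-1]  # deepest features plus the cross connections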

Hypothetically, you could instantiate a MobileNetV3 and perform runtime
“surgery”, but I don’t see a clear path to performing all of the necessary
steps, and, even if you could do so, it seems like writing a MobileNet-based
U-Net would be less of a pain in the neck than the surgery approach.
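
That said, if you did want to experiment with the surgery route, note
that mobilenetv3_large_100 is a timm model name, and timm will hand
you the intermediate feature maps itself via its features_only flag,
which at least covers the “hold on to the skip outputs” part. A
sketch, assuming timm is installed:

import timm
import torch

# forward() now returns a list of feature maps, one per
# downsampling stage, instead of classification logits
backbone = timm.create_model('mobilenetv3_large_100',
                             pretrained=True, features_only=True)

x = torch.randn(1, 3, 224, 224)
for f, ch in zip(backbone(x), backbone.feature_info.channels()):
    print(f.shape, ch)

You would still have to write the decoder (the up blocks) yourself, so
this only removes part of the work.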

You don’t say anything about using a pretrained model, but if you write
your own model and are careful about keeping the same names for all
of the layers, you ought to be able to load the state_dict from a
pretrained MobileNet into your model.
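
In code, that name-matching idea looks roughly like the following
(MyMobileNetUNet is a hypothetical stand-in for your hand-written
model; filtering by shape skips entries, such as the classifier head,
that your U-Net won’t have):

import timm

my_model = MyMobileNetUNet()  # hypothetical hand-written model
pretrained = timm.create_model('mobilenetv3_large_100', pretrained=True)

own = my_model.state_dict()
# keep only pretrained entries whose name and shape both match
matched = {k: v for k, v in pretrained.state_dict().items()
           if k in own and v.shape == own[k].shape}
own.update(matched)
my_model.load_state_dict(own)
print('initialized %d of %d tensors from pretrained weights'
      % (len(matched), len(own)))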

(If you were to succeed with the runtime surgery approach, you could
simply perform the surgery on a pretrained MobileNet.)

The original MobileNet starts off with a Conv2d with in_channels = 3,
corresponding to an RGB input. Simply change in_channels to 1 for
a grayscale input. (Alternatively, you could expand() your inputs
to have three channels, giving you an RGB input that is “grayscale”
because it has R = G = B.)
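
Both versions are one-liners; a quick sketch (the first conv’s shape
here is illustrative, not MobileNetV3’s exact stem):

import torch
import torch.nn as nn

# option 1: give the first conv a single input channel
first_conv = nn.Conv2d(1, 16, 3, stride=2, padding=1, bias=False)

# option 2: expand a grayscale batch to R = G = B; expand() returns
# a view, so no extra memory is allocated
gray = torch.randn(8, 1, 224, 224)
rgb = gray.expand(-1, 3, -1, -1)  # shape becomes (8, 3, 224, 224)

# if you use timm, create_model(..., in_chans=1) will adapt the
# pretrained stem to one channel for you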

As an aside, both MobileNet and your U-Net use padding (so that the
outputs of the convolutions have the same spatial extent as their inputs).
The original U-Net very specifically does not do this so as to keep the
“images” on either end of the cross connections properly aligned, avoiding
border artifacts when they are cat()ed together. You might consider
following the original U-Net strategy in this regard.
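
If you do follow the original strategy, the building block loses its
padding and the skip tensor gets center-cropped to match before the
cat(); a sketch of the two pieces:

class ValidDoubleConv(nn.Module):
    # double_conv without padding, as in the original U-Net paper;
    # each 3x3 convolution trims 2 pixels from height and width
    def __init__(self, in_ch, out_ch):
        super(ValidDoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

def center_crop(skip, target):
    # crop the larger skip tensor to target's spatial size
    dy = (skip.size(2) - target.size(2)) // 2
    dx = (skip.size(3) - target.size(3)) // 2
    return skip[:, :, dy:dy + target.size(2), dx:dx + target.size(3)]

In up.forward() you would then replace the F.pad() call with
x2 = center_crop(x2, x1) before the cat().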

Best.

K. Frank