How to use PNASNet5 as encoder in Unet in pytorch

I want to use PNASNet5Large as the encoder for my U-Net. Here is my approach; it works for ResNet but is wrong for PNASNet5Large:

class UNetResNet(nn.Module):
    """U-Net style segmentation network that reuses a ResNet as its encoder.

    The encoder stem (``conv1``/``bn1``/``relu`` plus a 2x2 max-pool) and the
    four residual stages ``layer1``..``layer4`` are taken from the selected
    backbone; their outputs are concatenated with the decoder features as
    skip connections in ``forward``.

    Args:
        encoder_depth: Backbone selector. 34, 101 and 152 pick the matching
            torchvision ResNet. 777 picks ``PNASNet5Large`` (see the note in
            ``__init__``; that branch does not work with this wiring).
        num_classes: Number of output channels of the final 1x1 convolution.
        num_filters: Base channel width used to size the decoder blocks.
        dropout_2d: Probability used for the 2D dropout applied right before
            the final convolution.
        pretrained: Forwarded to the torchvision ResNet constructor.
        is_deconv: Forwarded to every ``DecoderBlock`` (its meaning is
            defined by ``DecoderBlock``, which lives outside this class).

    Raises:
        ValueError: If ``encoder_depth`` is not one of the supported values.
    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        # bottom_channel_nr is the channel count produced by the deepest
        # encoder stage; the decoder input widths below are derived from it.
        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 777:
            # NOTE(review): PNASNet5Large exposes neither conv1/bn1/relu nor
            # layer1..layer4, so the attribute lookups below raise
            # AttributeError for this branch. The channel count is taken from
            # the input width of its last linear layer -- TODO confirm once
            # the encoder stages are wired up for this backbone.
            self.encoder = PNASNet5Large()
            bottom_channel_nr = 4320
        else:
            # Fail early with a clear message instead of an unrelated
            # AttributeError on self.encoder further down.
            raise ValueError(
                "Unsupported encoder_depth: {!r} (expected 34, 101, 152 or 777)".format(
                    encoder_depth))

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        # Encoder stages reused as the contracting path.
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        # Each decoder block receives the previous decoder output concatenated
        # with the matching encoder feature map, hence the summed input widths.
        self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2,
                                 num_filters * 8, is_deconv)
        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2,
                                 num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2,
                                 num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2,
                                 num_filters * 2 * 2, is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: Input batch fed to the encoder stem.

        Returns:
            The output of the final 1x1 convolution (``num_classes`` channels).
        """
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(conv5)

        # Skip connections: concatenate along the channel dimension (dim 1).
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        # The functional dropout defaults to training=True; tie it to the
        # module's mode so that model.eval() really disables it.
        dropped = F.dropout2d(dec0, p=self.dropout_2d, training=self.training)
        return self.final(dropped)
  1. How do I find out how many bottom channels PNASNet has? The network ends in the following way:


    self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
    in_channels_right=4320, out_channels_right=864)
    self.relu = nn.ReLU()
    self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
    self.dropout = nn.Dropout(0.5)
    self.last_linear = nn.Linear(4320, num_classes)

Is 4320 the answer or not? The `in_channels_left` and `out_channels_left` arguments are something new for me.

  1. ResNet has four big layers (`layer1`–`layer4`), which I use as the encoder stages in my U-Net architecture. How do I get similar layers from PNASNet?

I’m using PyTorch 3.1, and this is the link to the PNASNet directory.

  1. `AttributeError: 'PNASNet5Large' object has no attribute 'conv1'` — so PNASNet5Large doesn’t have a `conv1` attribute either.