Checking segmentation model implementation

Hey all! I implemented the encoder-decoder semantic segmentation architecture from this paper, basing my code on the built-in PyTorch SqueezeNet implementation. I don’t get any errors; however, I’m not sure whether I’ve designed it correctly. When I apply it to my dataset, which has two classes, the loss goes down and the IoU goes up, but the actual outputs look awful, and I’m not sure whether I just need to train it longer. Could anyone take a glance at my model code and tell me if anything seems obviously wrong? The code is fairly short. Thanks in advance!

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

class Fire(nn.Module):
    """SqueezeNet "Fire" block.

    A 1x1 "squeeze" convolution (+ ReLU) feeds two parallel "expand"
    branches, a 1x1 conv and a 3x3 conv with padding=1 (each + ReLU).
    The branch outputs are concatenated along the channel dimension.
    Both expand branches preserve the spatial size.

    Args:
        inplanes: number of input channels.
        squeeze_planes: channels produced by the 1x1 squeeze conv.
        expand1x1_planes: channels produced by the 1x1 expand branch.
        expand3x3_planes: channels produced by the 3x3 expand branch.

    The output has ``expand1x1_planes + expand3x3_planes`` channels, with
    the 1x1 branch first.
    """

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super().__init__()
        self.inplanes = inplanes
        # Submodules are registered in the original order, so the
        # state_dict layout and default-init RNG draws are unchanged.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(
            squeeze_planes, expand1x1_planes, kernel_size=1,
        )
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(
            squeeze_planes, expand3x3_planes, kernel_size=3, padding=1,
        )
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.squeeze_activation(self.squeeze(x))
        branches = [
            self.expand1x1_activation(self.expand1x1(squeezed)),
            self.expand3x3_activation(self.expand3x3(squeezed)),
        ]
        # Channel-wise concatenation: 1x1-branch channels come first.
        return torch.cat(branches, dim=1)


class FireDec(nn.Module):
    """Decoder-side counterpart of ``Fire``: expand first, then squeeze.

    The input goes through parallel 1x1 and 3x3 (padding=1) convolutions,
    each followed by ReLU. Their outputs are concatenated on the channel
    dimension and projected by a 1x1 convolution (+ ReLU) to
    ``squeeze_planes`` channels. The spatial size is preserved.

    Args:
        inplanes: number of input channels.
        squeeze_planes: number of output channels (final 1x1 conv).
        expand1x1_planes: channels produced by the 1x1 expand branch.
        expand3x3_planes: channels produced by the 3x3 expand branch.
    """

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super().__init__()
        # Registration order matches the original (expand branches, then
        # squeeze), so the state_dict layout and init RNG draws are unchanged.
        self.expand1x1 = nn.Conv2d(
            inplanes, expand1x1_planes, kernel_size=1,
        )
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(
            inplanes, expand3x3_planes, kernel_size=3, padding=1,
        )
        self.expand3x3_activation = nn.ReLU(inplace=True)
        concat_planes = expand1x1_planes + expand3x3_planes
        self.squeeze = nn.Conv2d(concat_planes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        expanded = torch.cat(
            [
                self.expand1x1_activation(self.expand1x1(x)),
                self.expand3x3_activation(self.expand3x3(x)),
            ],
            dim=1,
        )
        return self.squeeze_activation(self.squeeze(expanded))

class SqueezeSegNetEncoder(nn.Module):
    """SqueezeNet-style encoder that records what a decoder needs to unpool.

    Args:
        in_channels: number of channels of the input tensor.

    ``forward(x)`` returns ``(features, dim_ind, pool_ind)``:
        features: output of ``classifier_conv`` (1000 channels, after ReLU).
        dim_ind: list of the sizes of the three tensors fed to max-pooling,
            ordered shallowest to deepest.
        pool_ind: list of the matching ``max_pool2d`` index tensors.
    """

    def __init__(self, in_channels):
        super().__init__()

        # NOTE: this conv has no padding, so a spatial size S becomes
        # floor((S - 7) / 2) + 1.
        self.feature_block1 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=7, stride=2),
            nn.ReLU(inplace=True),
        )
        self.feature_block2 = nn.Sequential(
            Fire(96, 16, 64, 64),
            Fire(128, 16, 64, 64),
            Fire(128, 32, 128, 128),
        )
        self.feature_block3 = nn.Sequential(
            Fire(256, 32, 128, 128),
            Fire(256, 48, 192, 192),
            Fire(384, 48, 192, 192),
            Fire(384, 64, 256, 256),
        )
        self.feature_block4 = nn.Sequential(
            Fire(512, 64, 256, 256),
        )

        # The final 1x1 convolution is initialized differently from the
        # rest: small normal weights instead of Kaiming-uniform.
        # NOTE(review): the ReLU after it zeroes all negative bottleneck
        # activations before decoding; confirm this is wanted for segmentation.
        final_conv = nn.Conv2d(512, 1000, kernel_size=1)
        self.classifier_conv = nn.Sequential(final_conv, nn.ReLU(inplace=True))

        # Re-initialize every conv in module-traversal order, with zero biases.
        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            if conv is final_conv:
                init.normal_(conv.weight, mean=0.0, std=0.01)
            else:
                init.kaiming_uniform_(conv.weight)
            if conv.bias is not None:
                init.constant_(conv.bias, 0)

    def forward(self, x):
        dim_ind = []
        pool_ind = []
        for stage in (self.feature_block1, self.feature_block2,
                      self.feature_block3):
            x = stage(x)
            dim_ind.append(x.size())
            # Overlapping 3x3/stride-2 pooling; the indices let the decoder
            # place values back with max_unpool2d.
            x, indices = F.max_pool2d(x, kernel_size=3, stride=2,
                                      ceil_mode=True, return_indices=True)
            pool_ind.append(indices)
        features = self.classifier_conv(self.feature_block4(x))
        return features, dim_ind, pool_ind

class SqueezeSegNetDecoder(nn.Module):
    """Decoder that mirrors ``SqueezeSegNetEncoder`` using max-unpooling.

    Args:
        out_channels: number of output channels (e.g. one per class). The
            last layer has no activation, so the outputs are raw scores.

    ``forward`` consumes the ``(features, dim_ind, pool_ind)`` tuple that
    the encoder returns.
    """

    def __init__(self, out_channels):
        super(SqueezeSegNetDecoder, self).__init__()

        inverse_final_conv = nn.Conv2d(1000, 512, kernel_size=1)
        self.inverse_classifier_conv = nn.Sequential(
            inverse_final_conv,
            nn.ReLU(inplace=True),
        )

        self.inverse_feature_block4 = nn.Sequential(
            FireDec(512, 512, 256, 256),
        )

        self.inverse_feature_block3 = nn.Sequential(
            FireDec(512, 384, 256, 256),
            FireDec(384, 384, 192, 192),
            FireDec(384, 256, 192, 192),
            FireDec(256, 256, 128, 128),
        )

        self.inverse_feature_block2 = nn.Sequential(
            FireDec(256, 128, 128, 128),
            FireDec(128, 128, 64, 64),
            FireDec(128, 96, 64, 64),
        )

        # Inverts the encoder's unpadded 7x7/stride-2 first conv. For an
        # input of height H, that conv yields h = floor((H - 7) / 2) + 1
        # rows, and this layer maps h rows to 2*h + 6. That equals H for
        # even H but H + 1 for odd H (likewise for width). Pass
        # ``output_size`` to forward() to get exactly the input resolution.
        self.inverse_feature_block1 = nn.Sequential(
            nn.ConvTranspose2d(96, out_channels, kernel_size=10, stride=2, padding=1),
        )

    def forward(self, x, dim_ind, pool_ind, output_size=None):
        """Decode encoder features back to per-pixel scores.

        Args:
            x: features returned by the encoder.
            dim_ind: pre-pooling sizes returned by the encoder
                (shallowest first).
            pool_ind: max-pool indices returned by the encoder
                (shallowest first).
            output_size: optional target spatial size, either ``(H, W)`` or
                a full size such as ``image.size()``; only the last two
                entries are used. When given, the result is cropped at the
                bottom/right (or zero-padded there) to exactly that size.
                When ``None`` (the default), the raw transposed-conv output
                is returned, as before.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        x = self.inverse_classifier_conv(x)
        x = self.inverse_feature_block4(x)
        x = F.max_unpool2d(x, pool_ind[2], kernel_size=3, stride=2, output_size=dim_ind[2])
        x = self.inverse_feature_block3(x)
        x = F.max_unpool2d(x, pool_ind[1], kernel_size=3, stride=2, output_size=dim_ind[1])
        x = self.inverse_feature_block2(x)
        x = F.max_unpool2d(x, pool_ind[0], kernel_size=3, stride=2, output_size=dim_ind[0])
        x = self.inverse_feature_block1(x)
        if output_size is not None:
            x = self._match_spatial_size(x, output_size)
        return x

    @staticmethod
    def _match_spatial_size(x, output_size):
        """Crop or zero-pad the last two dims of ``x`` to ``output_size[-2:]``."""
        target_h, target_w = (int(s) for s in tuple(output_size)[-2:])
        x = x[..., :target_h, :target_w]
        # After cropping, each dim is <= its target, so padding is >= 0.
        pad_h = target_h - x.size(-2)
        pad_w = target_w - x.size(-1)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        return x

Hi,
Have you solved this issue? I am facing a somewhat similar problem. Please share your experience. Thanks!