How do I construct the decoder?


(oasjd7) #1

I tried to write a decoder corresponding to this encoder, but it failed.
How can I build the decoder?

class EncoderConv(nn.Module):
    """Convolutional encoder that flattens its feature map into one code vector per sample.

    Spatial sizes for a 28x28 single-channel input (the size at which the final
    unpadded 5x5 convolution reduces the map to 1x1):

        (N, 1, 28, 28)
        -> conv5 pad2 -> (N, 64, 28, 28)  -> maxpool2 -> (N, 64, 14, 14)
        -> conv5 pad0 -> (N, 128, 10, 10) -> maxpool2 -> (N, 128, 5, 5)
        -> conv5 pad0 -> (N, 256, 1, 1)
        -> flatten    -> (N, 256)
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 5x5 conv with padding 2 keeps the spatial size; the pool halves it.
        self.extractor1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: unpadded 5x5 convs each shrink the map by 4 pixels per side pair.
        self.extractor2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding=0),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (N, 1, H, W).

        Returns:
            Tensor of shape (N, C * H' * W'); (N, 256) for 28x28 inputs.
        """
        x = self.extractor1(x)
        x = self.extractor2(x)
        # Flatten using the actual batch size of x instead of a global setting,
        # so a smaller final batch (or single-sample inference) still works.
        return x.view(x.size(0), -1)


class DecoderConv(nn.Module):
    """Convolutional decoder mirroring ``EncoderConv``.

    The encoder stages are undone in reverse order. Each MaxPool2d(2) becomes
    nn.Upsample(scale_factor=2), applied to a map with the same channel count
    the pool saw. Each 5x5 Conv2d becomes a 5x5 ConvTranspose2d with the same
    padding. Shapes for a 256-dim code:

        (N, 256) -> reshape    -> (N, 256, 1, 1)
        -> convT5 pad0 -> (N, 128, 5, 5)   -> upsample2 -> (N, 128, 10, 10)
        -> convT5 pad0 -> (N, 64, 14, 14)
        -> upsample2   -> (N, 64, 28, 28)  -> convT5 pad2 -> (N, 1, 28, 28)

    nn.MaxUnpool2d is not used: its forward requires the indices returned by a
    MaxPool2d(return_indices=True), so it cannot sit inside nn.Sequential.
    """

    def __init__(self):
        super().__init__()
        self.decoder1 = nn.Sequential(
            # Inverse of the encoder's 128 -> 256 unpadded conv: 1x1 -> 5x5.
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=1, padding=0),
            nn.ReLU(True),
            # Inverse of the second MaxPool2d (which saw 128 channels): 5x5 -> 10x10.
            nn.Upsample(scale_factor=2, mode="nearest"),

            # Inverse of the encoder's 64 -> 128 unpadded conv: 10x10 -> 14x14.
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        # Attribute name matches the one used in forward().
        self.decoder2 = nn.Sequential(
            # Inverse of the first MaxPool2d (which saw 64 channels): 14x14 -> 28x28.
            # It must come before the 64 -> 1 conv, not after it.
            nn.Upsample(scale_factor=2, mode="nearest"),
            # Inverse of the first conv (1 -> 64, padding 2): keeps 28x28.
            nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=5, stride=1, padding=2),
            # NOTE(review): BatchNorm + ReLU on the single output channel are kept from
            # the original design; for targets scaled to [0, 1], a final nn.Sigmoid()
            # without BatchNorm is a common alternative — confirm against the loss used.
            nn.BatchNorm2d(1),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Decode a batch of codes back to images.

        Args:
            x: Either the flattened code (N, C) produced by EncoderConv.forward,
                or an already-shaped (N, C, 1, 1) tensor; C must be 256.

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        if x.dim() == 2:
            # The encoder flattens its output; restore the 1x1 spatial dims.
            x = x.view(x.size(0), -1, 1, 1)
        x = self.decoder1(x)
        x = self.decoder2(x)
        return x


#2

How did you try to construct the decoder, and what is not working?
Do you get any errors, or do you need some help with the shapes?


(oasjd7) #3

I don’t know where to place BatchNorm2d and ReLU in the decoder, or how to invert MaxPool2d in the decoder.
There is nn.MaxUnpool2d, but many decoders I have seen don’t use it.