MaxUnpooling in network class

I’m trying to construct a convolutional autoencoder wrapped in a class for convenience. However, when I perform the MaxUnpooling in the decoder I get a “missing indices” error because, as seen in many posts, you should pass in the indices returned by the MaxPooling in the encoder.

However, I’m unsure how to wrap this into the class…

class Autoencoder(torch.nn.Module):

    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2))

        self.decoder = torch.nn.Sequential(
            torch.nn.MaxUnpool2d(2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(256, 128, 3, padding=1),
            torch.nn.MaxUnpool2d(2),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(128, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(64, 1, 3, padding=1),
            torch.nn.ReLU())

    def forward(self, x):
        features = self.encoder(x)
        output = self.decoder(features)
        return output

It seems inefficient to split the decoder up at every MaxUnpool during the forward pass in order to call a separate unpooling function, so is there a nice way to wrap the MaxUnpooling inside the class?

Edit: I found this post from last year, so maybe this is still not possible: MaxUnpool2d with indices from MaxPool2d, all in nn.Sequential

Many thanks in advance for any advice!

You can rewrite the class, or create a new nn.Module that performs both the pooling and the unpooling.
The second option would be better in order to stay aligned with future upgrades.
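One way the second option could look (a minimal sketch; the class names PoolWithIndices and UnpoolWithIndices are made up for illustration): the pooling module stores the indices on each forward pass, and a paired unpooling module reads them back, so both can live inside nn.Sequential-style pipelines without changing the forward signature.

```python
import torch

class PoolWithIndices(torch.nn.Module):
    """Max-pools and remembers the indices from the last forward pass."""
    def __init__(self, kernel_size):
        super().__init__()
        self.pool = torch.nn.MaxPool2d(kernel_size, return_indices=True)
        self.indices = None  # refreshed on every forward call

    def forward(self, x):
        out, self.indices = self.pool(x)
        return out

class UnpoolWithIndices(torch.nn.Module):
    """Unpools using the indices stored by its paired PoolWithIndices."""
    def __init__(self, pool_module, kernel_size):
        super().__init__()
        self.unpool = torch.nn.MaxUnpool2d(kernel_size)
        self.pool_module = pool_module  # the paired pooling module

    def forward(self, x):
        return self.unpool(x, self.pool_module.indices)

pool = PoolWithIndices(2)
unpool = UnpoolWithIndices(pool, 2)
x = torch.randn(1, 1, 4, 4)
y = pool(x)    # shape (1, 1, 2, 2), indices cached inside `pool`
z = unpool(y)  # shape (1, 1, 4, 4), non-max positions zeroed
```

Note the design choice: the unpool module holds a reference to its pool module, which is harmless here since MaxPool2d has no parameters, but it couples the two modules, so they must always be used as a pair.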

Many thanks for your reply! Is this the sort of thing you mean?

class Autoencoder(torch.nn.Module):

    def __init__(self):
        super().__init__()

        self.conv1 = torch.nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = torch.nn.Conv2d(128, 256, 3, padding=1)
        self.relu = torch.nn.ReLU()
        self.mp = torch.nn.MaxPool2d(2, return_indices=True)
        self.up = torch.nn.MaxUnpool2d(2)
        self.trans1 = torch.nn.ConvTranspose2d(256, 128, 3, padding=1)
        self.trans2 = torch.nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.trans3 = torch.nn.ConvTranspose2d(64, 1, 3, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x, ind1 = self.mp(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        x, ind2 = self.mp(x)

        x = self.up(x, ind2)
        x = self.relu(x)
        x = self.trans1(x)
        x = self.up(x, ind1)
        x = self.relu(x)
        x = self.trans2(x)
        x = self.relu(x)
        x = self.trans3(x)
        x = self.relu(x)

        return x

It produces a shape error at x = self.up(x, ind1), but it might be working up to that point…
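The shape error is most likely because ind1 was recorded on a 64-channel map (after conv1), but self.up(x, ind1) is called while x still has 128 channels (after trans1). Reordering the decoder so it mirrors the encoder, unpooling with ind1 only once the channel count is back to 64, should resolve it. A sketch of the corrected class (same layers, reordered forward):

```python
import torch

class Autoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = torch.nn.Conv2d(128, 256, 3, padding=1)
        self.relu = torch.nn.ReLU()
        self.mp = torch.nn.MaxPool2d(2, return_indices=True)
        self.up = torch.nn.MaxUnpool2d(2)
        self.trans1 = torch.nn.ConvTranspose2d(256, 128, 3, padding=1)
        self.trans2 = torch.nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.trans3 = torch.nn.ConvTranspose2d(64, 1, 3, padding=1)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x, ind1 = self.mp(x)             # indices on a 64-channel map
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x, ind2 = self.mp(x)             # indices on a 256-channel map

        x = self.up(x, ind2)             # 256 channels: matches ind2
        x = self.relu(self.trans1(x))    # 256 -> 128
        x = self.relu(self.trans2(x))    # 128 -> 64
        x = self.up(x, ind1)             # 64 channels: now matches ind1
        x = self.relu(self.trans3(x))    # 64 -> 1
        return x
```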

I mean that you can write a custom nn.Module for each of the encoder and the decoder. The encoder’s forward can return several outputs, and the decoder’s forward can take several inputs. This way you can keep the packaged format you have.
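This suggestion could be sketched like so (a sketch, not a definitive implementation; the Encoder/Decoder class names and the three-output signature are illustrative choices): the encoder returns the features together with the pooling indices, and the decoder accepts all three.

```python
import torch

class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = torch.nn.Conv2d(128, 256, 3, padding=1)
        self.mp = torch.nn.MaxPool2d(2, return_indices=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x, ind1 = self.mp(x)
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x, ind2 = self.mp(x)
        return x, ind1, ind2          # features plus both index maps

class Decoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.trans1 = torch.nn.ConvTranspose2d(256, 128, 3, padding=1)
        self.trans2 = torch.nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.trans3 = torch.nn.ConvTranspose2d(64, 1, 3, padding=1)
        self.up = torch.nn.MaxUnpool2d(2)
        self.relu = torch.nn.ReLU()

    def forward(self, x, ind1, ind2):  # several inputs, as suggested
        x = self.up(x, ind2)
        x = self.relu(self.trans1(x))
        x = self.relu(self.trans2(x))
        x = self.up(x, ind1)
        x = self.relu(self.trans3(x))
        return x

class Autoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        features, ind1, ind2 = self.encoder(x)
        return self.decoder(features, ind1, ind2)
```

With this packaging, the state dict naturally splits into model.encoder.* and model.decoder.* keys, as mentioned below.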

Ah I understand… will that also help free up some memory, because I wouldn’t be adding so many computations to the graph history?

It will run the same computations :slight_smile: (you cannot make part of the model disappear).
But your state dict will have two main parts,
model.encoder and model.decoder.
It’s just a matter of readability and modularity.


Right right I understand, I’ll clean it up now! :slight_smile: thanks for the help

Note that the second way you built the model is also fine. You will just have everything together.

Makes sense, especially if you then want to call just the encoder after training to get the compressed data. Then you don’t have to write a separate model; you can just run the trained model in evaluation mode and ignore the decoder.
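That use case might look like the sketch below, using the Sequential-style encoder from the first post (so it works standalone, without returning indices): switch to eval mode, disable gradient tracking, and run only the encoder to get the compressed representation.

```python
import torch

# Encoder taken from the original Sequential model in the first post.
encoder = torch.nn.Sequential(
    torch.nn.Conv2d(1, 64, 3, padding=1), torch.nn.ReLU(),
    torch.nn.Conv2d(64, 128, 3, padding=1), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(128, 256, 3, padding=1), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
)

encoder.eval()               # e.g. model.eval() on the trained autoencoder
with torch.no_grad():        # no graph history needed for inference
    x = torch.randn(4, 1, 28, 28)
    codes = encoder(x)       # compressed features, decoder never runs
# codes.shape -> (4, 256, 7, 7): two 2x pools shrink 28x28 to 7x7
```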