Autoencoder testing encoder output

Hi PyTorch users,

I’m training an autoencoder, but I’m interested in the output of the encoder (i.e. the small linear layer in the middle) for my test values.

Is there an efficient way of testing the trained encoder on its own, rather than running the full network and pulling out the encoder's output values?

If not, how does one pull out the encoder's output values?

My autoencoder looks as follows:

import torch
import torch.nn as nn

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        # Encoder: 4 input features -> 5-dimensional code
        self.encoder = nn.Sequential(
            nn.Linear(4, 32),
            nn.ReLU(True),
            nn.Linear(32, 12),
            nn.ReLU(True),
            nn.Linear(12, 5))
        # Decoder: 5-dimensional code -> 4 reconstructed features
        self.decoder = nn.Sequential(
            nn.Linear(5, 12),
            nn.ReLU(True),
            nn.Linear(12, 32),
            nn.ReLU(True),
            nn.Linear(32, 4),
            )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

It may help to know that I save my models as follows:

torch.save(model, save_path + 'autoencode_epoch_'+str(epoch))
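If the goal is to test the encoder on its own, one option (swapped in here, not the approach used in the question) is to save only the encoder's `state_dict()`; the PyTorch "Saving and Loading Models" tutorial recommends saving `state_dict`s over pickling whole modules. The sketch below assumes a hypothetical `make_encoder()` helper that rebuilds the same encoder layers and a hypothetical file name; the test values are random placeholders. If you keep saving the full model with `torch.save(model, ...)`, note that PyTorch 2.6 and later default to `torch.load(..., weights_only=True)`, so loading such a file needs `weights_only=False` (only do that for files you trust).

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_encoder():
    # Hypothetical helper: the same layers as autoencoder.encoder in the question.
    return nn.Sequential(
        nn.Linear(4, 32), nn.ReLU(True),
        nn.Linear(32, 12), nn.ReLU(True),
        nn.Linear(12, 5))


class autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = make_encoder()
        self.decoder = nn.Sequential(
            nn.Linear(5, 12), nn.ReLU(True),
            nn.Linear(12, 32), nn.ReLU(True),
            nn.Linear(32, 4))

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = autoencoder()  # stands in for the trained model

# Save only the encoder's parameters (hypothetical file name in a temp dir).
path = os.path.join(tempfile.mkdtemp(), 'encoder_epoch_10.pt')
torch.save(model.encoder.state_dict(), path)

# Later, e.g. in a separate test script: rebuild the encoder and load its weights.
encoder = make_encoder()
encoder.load_state_dict(torch.load(path))
encoder.eval()

x_test = torch.randn(3, 4)  # placeholder test values: 3 samples, 4 features
with torch.no_grad():  # inference only, no autograd graph
    codes = encoder(x_test)

print(codes.shape)  # torch.Size([3, 5])
```

Keeping the encoder layers in one helper means the training model and the standalone test encoder cannot drift apart, so the saved keys (`0.weight`, `0.bias`, ...) always match.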

I think this is probably a simple problem, but it would be nice to see how people think it is best done. Thanks in advance!

EDIT: Please ignore the fact that I scale a 4-parameter input up and then encode it into 5 variables; this is just a test case.

Since your forward method just chains the encoder and the decoder, you can call the encoder directly:

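import torch

model = autoencoder()  # or your trained model
x = torch.randn(1, 4)  # one sample with 4 input features
with torch.no_grad():  # no gradients needed for testing
    enc_output = model.encoder(x)  # shape [1, 5]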
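To encode a whole test set rather than a single sample, a minimal sketch (assuming the test data fits in one tensor; `x_test` and the batch size are placeholders) runs only `model.encoder` over batches in eval mode without gradient tracking. The final assertion checks the reasoning above: since `forward` is just `decoder(encoder(x))`, decoding the stored codes reproduces the full network's output.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class autoencoder(nn.Module):
    # Same architecture as in the question.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(4, 32), nn.ReLU(True),
            nn.Linear(32, 12), nn.ReLU(True),
            nn.Linear(12, 5))
        self.decoder = nn.Sequential(
            nn.Linear(5, 12), nn.ReLU(True),
            nn.Linear(12, 32), nn.ReLU(True),
            nn.Linear(32, 4))

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = autoencoder()
model.eval()  # no dropout/batchnorm here, but a good habit when testing

x_test = torch.randn(100, 4)  # placeholder for the real test set
loader = DataLoader(TensorDataset(x_test), batch_size=32, shuffle=False)

codes = []
with torch.no_grad():  # inference only: skip building the autograd graph
    for (batch,) in loader:
        codes.append(model.encoder(batch))  # only the encoder runs
codes = torch.cat(codes)  # shape [100, 5], rows in the same order as x_test

# Sanity check: decoding the stored codes matches the full forward pass.
with torch.no_grad():
    assert torch.allclose(model.decoder(codes), model(x_test), atol=1e-5)
```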

Of course, this wouldn’t work if your model applied other operations inside forward besides the encoder and decoder calls.

Another approach would be to use forward hooks to get the desired output.
Here is an example for a UNet model. With hooks, however, you would still run the complete forward pass and just store the activations from the encoder.
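A minimal forward-hook sketch for this autoencoder (not the linked UNet code) could look like the following; `activations` and `save_output` are hypothetical names. The hook is registered on `model.encoder`, so it fires during the full forward pass and stores the encoder's output.

```python
import torch
import torch.nn as nn


class autoencoder(nn.Module):
    # Same architecture as in the question.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(4, 32), nn.ReLU(True),
            nn.Linear(32, 12), nn.ReLU(True),
            nn.Linear(12, 5))
        self.decoder = nn.Sequential(
            nn.Linear(5, 12), nn.ReLU(True),
            nn.Linear(12, 32), nn.ReLU(True),
            nn.Linear(32, 4))

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = autoencoder()
model.eval()

activations = {}  # hypothetical storage for captured outputs


def save_output(name):
    # Returns a hook that stores the module's output under `name`.
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook


handle = model.encoder.register_forward_hook(save_output('encoder'))

x = torch.randn(1, 4)
with torch.no_grad():
    reconstruction = model(x)  # full forward pass; the hook fires when the encoder runs

enc_output = activations['encoder']  # shape [1, 5]
handle.remove()  # stop capturing once the activations are no longer needed
```

`register_forward_hook` returns a handle; calling `handle.remove()` keeps the hook from firing (and overwriting `activations`) on later forward passes.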


That works perfectly! I’ll make sure to keep the forward pass restricted to an encoder and decoder to avoid any problems.

I’m unfamiliar with hooks, but from the code in your link and your description here I get the idea. I assume the hook method is only necessary if the forward pass does more than call the encoder and decoder.

Many thanks as always!
