VAE train vs test mode

Hello,

I’m training a VAE and I’m a bit confused as to what I need to do during testing.

Based on what I’ve read, it seems that during training we use the reparameterization trick (as per the Kingma paper) and then during testing we simply set z to epsilon, i.e. a sample from N(0,1). This is how I’m trying to implement this:

    def train_reparam(self, h):
        mean, logvar = self.fc1a(h), self.fc1b(h)  
        std_dev = torch.exp(0.5*logvar)
        eps = torch.randn_like(std_dev)        
        z = mean + std_dev * eps
        return z, mean, logvar
    
    def test_reparam(self, h):
        mean, logvar = self.fc1a(h), self.fc1b(h)  
        std_dev = torch.exp(0.5*logvar)
        eps = torch.randn_like(std_dev)        
        z = eps # z is now eps (norm dist)
        return z, mean, logvar

    def forward(self, x):
        h = self.encoder(x)
        if self.training:
            z, mean, logvar = self.train_reparam(h)
        else:
            z, mean, logvar = self.test_reparam(h)
        z = self.fc2(z)
        return self.decoder(z), mean, logvar

The results obtained with the above, however, are confusing. The test loss seems to be lower than the training loss, and the quality of the images generated during testing is poor.

What also confuses me is that when I look for examples of VAE code, many of them don’t seem to operate differently during testing (i.e. they reparam at both train and test time), even the PyTorch VAE example!

So my questions are:

  1. Is what I’m doing correct? Do I need this train/test distinction?
  2. If so, should I return z, mean & logvar during test or should that be replaced with z, 0, 1?

Any help would be greatly appreciated!

Thanks,
NK

Firstly, the PyTorch version of the VAE works fine.
Regarding the different losses, here are some of the possibilities:

  1. Different Batch Sizes in test vs train
  2. The samples you generate from the VAE may come from different classes during testing than during training.

Hi!

I’ve used the Pytorch VAE model before (and it did indeed work well!) but the Pytorch architecture doesn’t use the whole train/test idea mentioned by Kingma. After learning about the non-reparam pass during testing I don’t understand why that is the case.

Is the Kingma solution the “right way” to do things but we choose to do things the “wrong way”…?

Thanks,
NK

Firstly, I am not sure about your conclusion that “Pytorch architecture doesn’t use the whole train/test idea mentioned by Kingma”, because the PyTorch version runs for 10 epochs with a batch size of 128 and uses the complete dataset in each epoch.

Secondly, the reparam trick should be used only during training, not during testing. The entire goal of an AE or VAE is to learn a latent space that is a good representation of the data distribution (clustering, i.e. keeping samples of the same class close together). Once you finish training, you can sample anything from the latent space and you will get good samples.

What I meant was that in the PyTorch VAE code they use the reparam trick, i.e. z = mean + eps*std with std = exp(0.5*logvar), during train AND test (instead of that during train and z = eps during test). Or am I missing something?

So is the code below a correct replacement for reparam during test (so basically correct sampling)? I’m returning z=eps (so N(0,1)) and I guess the mean&logvar should be calculated like before? (and then used for KL divergence)

    def test_reparam(self, h):
        mean, logvar = self.fc1a(h), self.fc1b(h)
        std_dev = torch.exp(0.5*logvar)
        eps = torch.randn_like(std_dev)
        z = eps # z is now eps (norm dist)
        return z, mean, logvar

Thanks :)

Or does z become equal to eps “automatically”, e.g. thanks to model.eval() or something else happening “behind the scenes”, so that we don’t have to implement separate train and test passes ourselves…?

To answer your first question about calculating the loss during both test and train:
They compute the loss at test time only to check the error on the test set, i.e. to compare the reconstruction with the true image. Nothing else happens because of that: loss.backward() is never called there, so no model parameters are updated.
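Here is a minimal sketch of what such a test-time check could look like. It assumes a model whose forward pass returns (reconstruction, mean, logvar) like the one in the question. `TinyVAE`, its layer sizes, the fake `x_test` batch and the `vae_loss` helper (summed binary cross-entropy plus the closed-form KL term) are all made up for illustration and are not copied from the PyTorch example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyVAE(nn.Module):
    """Hypothetical stand-in model; all sizes are arbitrary."""

    def __init__(self, in_dim=784, hidden=64, latent=20):
        super().__init__()
        self.enc = nn.Linear(in_dim, hidden)
        self.fc_mean = nn.Linear(hidden, latent)
        self.fc_logvar = nn.Linear(hidden, latent)
        self.dec1 = nn.Linear(latent, hidden)
        self.dec2 = nn.Linear(hidden, in_dim)

    def forward(self, x):
        h = F.relu(self.enc(x))
        mean, logvar = self.fc_mean(h), self.fc_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)  # reparam trick
        recon = torch.sigmoid(self.dec2(F.relu(self.dec1(z))))
        return recon, mean, logvar


def vae_loss(recon, x, mean, logvar):
    # Assumed loss: reconstruction term + KL(q(z|x) || N(0, I)).
    bce = F.binary_cross_entropy(recon, x, reduction="sum")
    kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return bce + kld


model = TinyVAE()
x_test = torch.rand(8, 784)  # fake test batch with values in [0, 1)

model.eval()            # switch to evaluation mode
with torch.no_grad():   # no graph is built, so nothing can be updated
    recon, mean, logvar = model(x_test)
    test_loss = vae_loss(recon, x_test, mean, logvar).item()

print(f"test loss per sample: {test_loss / x_test.size(0):.2f}")
```

The loss here is only a number to compare against the training loss. Since there is no backward pass and no optimizer step, the weights are the same before and after.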

You can collect mu and logvar as needed. But as I mentioned previously, you only pass z (the “latent vector”) to the decoder, which gives you the output image. So only z is really important.

So even if you pass random noise to the decoder, it will generate an image.

Ok, I feel like I understand all the loss stuff now :)

I’m still confused about z though. Do I have to define a separate case for training and testing, eg

    if self.training:
        z = mean + std_dev*eps # (1)
    else:
        z = eps

and then in the forward pass call self.decoder(z)

OR

can I just use the line marked with (1) and Pytorch will read it as z = eps when running in .eval() mode?

Thanks a lot for your help!

Logically, you don’t need a separate case. You can just use the code block given in the PyTorch example. (It depends on what you want. If I want to generate some samples, I would use something like this.)
Method I:

    zsample = torch.randn(64, 20).to(device)
    zsample = model.decode(zsample).cpu()

Method II:

The reparam function only comes into the picture if you pass something as input to the encoder: it then produces mu and logvar, samples z with the reparam trick, and passes that z to the decoder to produce the output image. (Usually people don’t do this; they just take the decoder and pass in whatever they want.)
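To put the two methods side by side, here is a rough sketch. `SketchVAE`, its layer sizes and the `encode` / `reparameterize` / `decode` method names are assumptions chosen for illustration, loosely following the structure discussed in this thread rather than the exact PyTorch example. It also illustrates that `model.eval()` only flips the `training` flag (which matters for layers such as dropout or batch norm); it does not turn `z = mean + std * eps` into `z = eps`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchVAE(nn.Module):
    """Hypothetical model; all sizes are arbitrary."""

    def __init__(self, in_dim=784, hidden=64, latent=20):
        super().__init__()
        self.latent = latent
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc_mean = nn.Linear(hidden, latent)
        self.fc_logvar = nn.Linear(hidden, latent)
        self.fc3 = nn.Linear(latent, hidden)
        self.fc4 = nn.Linear(hidden, in_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mean(h), self.fc_logvar(h)

    def reparameterize(self, mean, logvar):
        # Same code path in train and eval mode: z ~ N(mean, std^2).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + std * eps

    def decode(self, z):
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar


model = SketchVAE()
model.eval()  # sets model.training = False; the arithmetic above is unchanged

with torch.no_grad():
    # Method I: generate new images by decoding samples from the prior N(0, I).
    z_prior = torch.randn(64, model.latent)
    generated = model.decode(z_prior)

    # Method II: reconstruct given inputs (encoder -> reparam -> decoder).
    x = torch.rand(4, 784)  # fake input batch
    recon, mean, logvar = model(x)

print(generated.shape, recon.shape)
```

Method I is what you want for generating new samples, and it never touches the encoder or the reparam function. Method II is what you want for reconstructions and for computing a test loss, and there z still depends on the input through mean and logvar.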