I’m training a VAE and I’m a bit confused as to what I need to do during testing.
Based on what I’ve read, it seems that during training we use the reparam trick (as per the Kingma paper), and then during testing we simply sample epsilon, i.e. from N(0, 1). This is how I’m trying to implement it:
def train_reparam(self, h):
    mean, logvar = self.fc1a(h), self.fc1b(h)
    std_dev = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std_dev)
    z = mean + std_dev * eps
    return z, mean, logvar

def test_reparam(self, h):
    mean, logvar = self.fc1a(h), self.fc1b(h)
    std_dev = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std_dev)
    z = eps  # z is now eps, i.e. a sample from N(0, 1)
    return z, mean, logvar

def forward(self, x):
    h = self.encoder(x)
    if self.training:
        z, mean, logvar = self.train_reparam(h)
    else:
        z, mean, logvar = self.test_reparam(h)
    z = self.fc2(z)
    return self.decoder(z), mean, logvar
The results obtained with the above, however, are confusing. The test loss is lower than the training loss, and the quality of the images generated during testing is poor.
What also confuses me is that when I look at examples of VAE code, many of them don’t seem to operate differently during testing (i.e. they reparam at both train and test) - even the PyTorch VAE example!
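For comparison, the `reparameterize` step in the PyTorch VAE example looks roughly like the following standalone sketch (the tensor sizes in the usage lines are only for illustration):

```python
import torch

def reparameterize(mu, logvar):
    # Standard reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
    # In the PyTorch example this same path is used at train and eval time.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(4, 20)
logvar = torch.zeros(4, 20)  # sigma = exp(0.5 * 0) = 1
z = reparameterize(mu, logvar)
print(z.shape)  # torch.Size([4, 20])
```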
So my questions are:
Is what I’m doing correct? Do I need this train/test distinction?
If so, should I return z, mean & logvar during test or should that be replaced with z, 0, 1?
I’ve used the PyTorch VAE model before (and it did indeed work well!), but the PyTorch architecture doesn’t use the train/test distinction described by Kingma. After learning about the non-reparam pass during testing, I don’t understand why that is the case.
Is the Kingma solution the “right way” to do things but we choose to do things the “wrong way”…?
Firstly, I’m not sure about your conclusion that the “Pytorch architecture doesn’t use the whole train/test idea mentioned by Kingma”, because the PyTorch version runs for 10 epochs with a batch size of 128 and uses the complete dataset.
Secondly, the reparam trick should be used only during training, not during testing. The entire goal of an AE or VAE is to learn a latent space that is a correct representation of our data distribution (clustering, i.e. keeping samples of the same class close together), and once you finish training you can sample anything from the latent space and you will get good samples.
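In code, “reparam only during training” is commonly expressed by branching on the module’s `training` flag; a minimal sketch (the module layout, `fc1a`/`fc1b` names, and sizes are assumptions mirroring the question - only the branching matters):

```python
import torch
import torch.nn as nn

class Bottleneck(nn.Module):
    # Illustrative module; fc1a/fc1b mirror the names used in the question.
    def __init__(self, h_dim=400, z_dim=20):
        super().__init__()
        self.fc1a = nn.Linear(h_dim, z_dim)  # predicts the mean
        self.fc1b = nn.Linear(h_dim, z_dim)  # predicts the log-variance

    def forward(self, h):
        mean, logvar = self.fc1a(h), self.fc1b(h)
        if self.training:  # toggled by model.train() / model.eval()
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mean + std * eps  # stochastic sample for training
        else:
            z = mean  # deterministic posterior mean at eval time
        return z, mean, logvar

m = Bottleneck().eval()  # eval() sets m.training = False
z, mean, logvar = m(torch.randn(2, 400))
print(torch.equal(z, mean))  # True in eval mode
```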
What I meant was that in the PyTorch VAE code they use the reparam trick, i.e. z = mean + eps*std, during train AND test (instead of using it during train and z = eps during test) - or am I missing something?
So is the code below a correct replacement for the reparam step during test (i.e. correct sampling)? I’m returning z = eps (so N(0, 1)), and I guess mean & logvar should be calculated as before? (and then used for the KL divergence)
def test_reparam(self, h):
    mean, logvar = self.fc1a(h), self.fc1b(h)
    std_dev = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std_dev)
    z = eps  # z is now eps, i.e. a sample from N(0, 1)
    return z, mean, logvar
Or does z become equal to eps “automatically”, e.g. thanks to model.eval() or something else happening “behind the scenes”, so that we don’t have to implement separate train and test passes ourselves…?
To answer your first question about calculating the loss during test and train:
They compute the test loss simply to check the error on the test set and compare the output against the true images. Nothing happens to the model because of that: you never call loss.backward() during testing, so no model parameters are updated.
You can collect mu and logvar as needed, but as I mentioned previously, only z (the “latent vector”) is passed to the decoder, which produces the output image. So only z really matters.
So even if you pass random noise to the decoder, it will generate an image.
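That is, for generation you skip the encoder and the reparam step entirely and decode noise drawn from the prior; a sketch with a stand-in decoder (the layer sizes are assumptions, not the thread’s actual model):

```python
import torch
import torch.nn as nn

# Stand-in decoder; sizes (20 -> 784) are illustrative only.
decoder = nn.Sequential(
    nn.Linear(20, 400),
    nn.ReLU(),
    nn.Linear(400, 784),
    nn.Sigmoid(),  # pixel values in (0, 1)
)

with torch.no_grad():
    z = torch.randn(64, 20)  # z ~ N(0, I), sampled straight from the prior
    samples = decoder(z)     # one decoded image per latent sample
print(samples.shape)  # torch.Size([64, 784])
```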
Logically you don’t need any separate case; you can just use the code block given in the PyTorch example. (It depends on what you want - if I want to generate some samples I will use something like this.)
Method I:
(The reparam function only comes into the picture if you pass an input to the encoder: it generates z, mu, logvar using the reparam trick, then takes z and passes it to the decoder to produce the output image. Usually people don’t do this for sampling - they just take the decoder and pass it whatever they want.)
Hello, I know this post is from years ago. I think the issue is simply that you shouldn’t use eps during testing; instead, you should use the mean. A sample from my code looks like this:
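A sketch of that idea, assuming a plain helper function rather than any particular model (this is a reconstruction of the approach described, not the poster’s original snippet): sample during training, return the posterior mean at test time.

```python
import torch

def reparameterize(mean, logvar, training):
    if training:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + std * eps  # stochastic sample during training
    return mean  # at test time use the posterior mean, not eps

mean, logvar = torch.zeros(5, 20), torch.zeros(5, 20)
z = reparameterize(mean, logvar, training=False)
print(torch.equal(z, mean))  # True
```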