I am wondering why the latent parameters (mu and logvar) of my VAE converge to fixed values during training and then stay constant regardless of the input image. At initialization the latent parameters differ between inputs, as expected. The VAE still reconstructs images decently: my attached figure showed input samples (top) and their reconstructions (bottom) for some sprites from the dSprites dataset.
I originally suspected that I was accidentally passing in the same data every time, which would explain the identical latent parameters, but the reconstructions show that different images are being passed in.
I have also noticed that, no matter what values I set the latent parameters to, the VAE still reconstructs the input, as long as I keep the indices and sizes returned by the max-pool layers unchanged. My decoder has a MaxUnpool layer, which requires those indices and sizes.
Here is the code for my VAE.
class Reparameterize(nn.Module):
    """Split an encoder output into (mu, logvar) and draw a latent sample.

    The first ``h_dim`` columns of the input are treated as means and every
    remaining column as a log-variance. The sample is drawn with the
    reparameterization trick so gradients flow through ``mu`` and ``logvar``.
    """

    def __init__(self, h_dim):
        super().__init__()
        # Column index at which the means end and the log-variances begin.
        self.h_dim = h_dim

    def forward(self, x):
        """Return ``(sample, mu, logvar)`` for a 2-D input ``x``."""
        split = self.h_dim
        mu = x[:, :split]
        logvar = x[:, split:]
        # exp(logvar / 2) converts a log-variance into a standard deviation.
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std, mu, logvar
class BVAE(nn.Module):
    """Convolutional (beta-)VAE whose decoder un-pools with the encoder's pool indices.

    Args:
        h_dim: Split point handed to ``Reparameterize``. The columns of the
            ``enc_lin`` output before ``h_dim`` are means and the rest are
            log-variances (presumably ``enc_lin`` emits ``2 * h_dim``
            features; confirm with the caller).
        enc_conv: Either a single module applied as-is, or an
            ``nn.ModuleList``. Each item of that list is one of:
            - a nested ``nn.ModuleList``, whose layers are run in order;
            - a pooling layer returning ``(output, indices)``, e.g.
              ``nn.MaxPool2d(..., return_indices=True)``.
        enc_lin: Iterable of layers applied to the flattened conv features.
        dec_lin: Iterable of layers applied to the latent sample.
        dec_conv: Iterable of nested ``nn.ModuleList`` blocks or unpooling
            layers. Unpooling layers are called as
            ``layer(x, indices=..., output_size=...)``.
        in_layers: ``in_layers[0]`` is the flattened encoder feature size.
            Items 1..3 of ``in_layers[1]`` give the shape the decoder's linear
            output is reshaped to. Presumably ``in_layers[1]`` is a full
            ``(N, C, H, W)`` shape; confirm with the caller.

    NOTE(review): the pooling ``indices`` are computed from the input image
    and reach the decoder without passing through the latent bottleneck. The
    decoder can therefore reconstruct the input largely from the indices
    alone, which matches the observation that reconstructions survive
    arbitrary latent values. With z unnecessary for reconstruction, the KL
    term can push mu/logvar to the prior for every input (posterior
    collapse). Consider decoding with ``nn.Upsample``/``nn.ConvTranspose2d``
    instead of ``MaxUnpool`` so that all information must pass through z.
    """

    def __init__(self, h_dim, enc_conv, enc_lin, dec_lin, dec_conv, in_layers):
        super().__init__()
        self.h_dim = h_dim
        self.enc_conv = enc_conv
        self.enc_lin = enc_lin
        self.reparam = Reparameterize(h_dim)
        self.dec_lin = dec_lin
        self.dec_conv = dec_conv
        self.enc_lin_input = in_layers[0]
        self.dec_conv_input = [in_layers[1][1], in_layers[1][2], in_layers[1][3]]

    def encode(self, x):
        """Run the encoder.

        Returns:
            ``(h, indices, sizes)``:
            - ``h`` is the ``enc_lin`` output that is fed to
              ``Reparameterize``.
            - ``indices`` are the pooling indices, in encoder order.
            - ``sizes`` are the pre-pooling input sizes, in encoder order.
            ``decode`` needs ``indices`` and ``sizes`` to un-pool. Both lists
            are empty when ``enc_conv`` is not an ``nn.ModuleList``.
        """
        indices = []
        sizes = []
        if isinstance(self.enc_conv, nn.ModuleList):
            for layer in self.enc_conv:
                if isinstance(layer, nn.ModuleList):
                    for sub in layer:
                        x = sub(x)
                else:
                    # Pooling layer: record its input size so the matching
                    # unpooling layer can restore it exactly, then unpack the
                    # (output, indices) pair it returns.
                    sizes.append(x.size())
                    x, idx = layer(x)
                    indices.append(idx)
        else:
            x = self.enc_conv(x)
        # reshape (not view) also tolerates non-contiguous conv outputs.
        x = x.reshape(-1, self.enc_lin_input)
        for layer in self.enc_lin:
            x = layer(x)
        return x, indices, sizes

    def decode(self, sample, indices, sizes):
        """Decode a latent ``sample`` using the encoder's pooling ``indices``/``sizes``.

        The lists are read last-in-first-out, so the first unpooling layer in
        ``dec_conv`` uses the last pooling layer's indices. The caller's lists
        are left intact, so the same encoder outputs can be decoded repeatedly.
        """
        # Work on copies: popping the caller's lists would make a second
        # decode with the same indices fail.
        pending_indices = list(indices)
        pending_sizes = list(sizes)
        x = sample
        for layer in self.dec_lin:
            x = layer(x)
        x = x.view(-1, *self.dec_conv_input)
        for layer in self.dec_conv:
            if isinstance(layer, nn.ModuleList):
                for sub in layer:
                    x = sub(x)
            else:
                x = layer(x, indices=pending_indices.pop(),
                          output_size=pending_sizes.pop())
        return x

    def forward(self, x):
        """Encode, sample, and decode ``x``.

        Returns:
            ``(reconstruction, mu, logvar)``, with the reconstruction viewed
            back to ``x``'s shape.
        """
        x_size = x.size()
        h, indices, sizes = self.encode(x)
        sample, mu, logvar = self.reparam(h)
        recon = self.decode(sample, indices, sizes)
        return recon.view(x_size), mu, logvar

    def bce_loss(self, reconstruction, x):
        """Return the mean-squared reconstruction error.

        Despite its name, this computes MSE (mean over all elements), not
        binary cross-entropy. The name is kept for backward compatibility.
        NOTE(review): a mean-reduced reconstruction term is small relative to
        the KL term, which also encourages the latents to collapse to the
        prior.
        """
        return nn.functional.mse_loss(reconstruction, x, reduction='mean')
Does anyone have thoughts about why this happens?
Edit: added the code for computing the KL-divergence loss:
def kl_div(h_dim, sample, q=None, hp=None, require_grad=True, individual_elements=False,
           device='cpu'):
    """
    Calculates kl-divergence for multiple scenarios.

    ``sample`` is the raw encoder output ``[mu | logvar]``, not a
    reparameterized sample. The first ``h_dim`` columns are means and the
    remaining columns are log-variances. A 1-D ``sample`` is treated as a
    batch of one.

    I. No q (prior) passed in (``q`` is None or empty):
        Assume unit Gaussian prior
        a. If require_grad:
            Calculate kl-divergence with the closed form
        b. else:
            Use torch kl_divergence with torch Normal distributions
        Then, for both a. and b.:
            i. If individual_elements:
                Return kl-divergence for each latent dimension, shape (batch, h_dim)
            ii. Else:
                Average across latent dimensions, then:
                1. If hp = (gamma, c) passed in:
                    Return Beta-VAE capacity loss gamma * |mean_batch(kl - c)|
                2. Else:
                    Return the mean over the batch (a scalar)
    II. q (prior) passed in as a tensor ``[batch, mu/logvar]`` (or 1-D):
        Return torch kl_divergence(posterior || q) averaged over latent
        dimensions, shape (batch,).

    ``device`` is kept for backward compatibility. Prior tensors are created
    on, or moved to, the device and dtype of ``sample`` so the two always match.

    Raises:
        ValueError: if ``q`` is neither empty/None nor a tensor.
    """
    if sample.dim() == 1:
        sample = sample.unsqueeze(0)
    mu, logvar = sample[:, :h_dim], sample[:, h_dim:]

    if q is None or len(q) == 0:
        if require_grad:
            # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per latent dimension.
            k = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        else:
            # Normal takes a standard deviation (scale), hence exp(logvar / 2).
            posterior = Normal(mu, torch.exp(0.5 * logvar))
            prior = Normal(torch.zeros_like(mu), torch.ones_like(logvar))
            k = kl_divergence(posterior, prior)
        if individual_elements:
            return k
        per_example = torch.mean(k, dim=1)
        if hp is None or len(hp) == 0:
            return torch.mean(per_example, dim=0)
        gamma, c = hp
        return gamma * torch.abs(torch.mean(per_example - c, dim=0))

    if isinstance(q, torch.Tensor):
        if q.dim() == 1:
            q = q.unsqueeze(0)
        q = q.to(device=mu.device, dtype=mu.dtype)
        muq, logvarq = q[:, :h_dim], q[:, h_dim:]
        # Both distributions are parameterized by std, not log-variance;
        # passing logvarq directly as the scale was wrong and rejected
        # negative log-variances.
        posterior = Normal(mu, torch.exp(0.5 * logvar))
        prior = Normal(muq, torch.exp(0.5 * logvarq))
        return torch.mean(kl_divergence(posterior, prior), dim=-1)

    raise ValueError("q must be tensor [batch, mu/logvar]")