# Why doesn't the VAE loss converge to zero?

Hi everyone!
I’m using a Variational Autoencoder (VAE), and this is my implementation of the loss function:

``````python
class VariationalAutoencoder(nn.Module):
# ...some functions...

def gaussian_likelihood(self, x_hat, logscale, x):
    """Log-likelihood log p(x|z) of the input x under N(x_hat, exp(logscale)).

    The decoder output x_hat is taken as the mean of a Gaussian observation
    model whose (learned) log standard deviation is `logscale`. The per-pixel
    log-densities are summed over channel and spatial axes, yielding one
    scalar per batch element.
    """
    decoder_dist = torch.distributions.Normal(x_hat, torch.exp(logscale))
    # Reduce the elementwise log-density to a single value per sample.
    per_pixel = decoder_dist.log_prob(x)
    return per_pixel.sum(dim=(1, 2, 3))

def forward(self, input):
    """Run the full VAE pass: encode, sample a latent, decode.

    Returns a 4-tuple (reconstruction, mu, logvar, z) so callers get both
    the decoder output and the latent statistics needed by the loss.
    """
    mu, logvar = self.encode(input)
    latent = self.reparameterise(mu, logvar)
    reconstruction = self.decoder(latent)
    return reconstruction, mu, logvar, latent

def loss_function(self, x_hat, x, mu, logvar, β=1):
    """Negative ELBO: β * KL(q(z|x) || p(z)) - E[log p(x|z)], averaged over the batch.

    Fix: the β parameter was accepted but never used, so passing β != 1 had no
    effect. It now weights the KL term (standard β-VAE objective); the default
    β=1 preserves the original behavior exactly.

    Note: this objective is log-likelihood-based, so its minimum is NOT zero —
    a well-fit model drives it negative.
    """
    # Convert log-variance to standard deviation: std = exp(logvar / 2).
    std = torch.exp(logvar / 2)

    # Posterior q(z|x); draw a fresh sample for the Monte Carlo KL estimate.
    # NOTE(review): this z is not the same sample that produced x_hat —
    # consider passing the z returned by forward() for a lower-variance,
    # self-consistent estimate.
    q = torch.distributions.Normal(mu, std)
    z = q.rsample()

    # Reconstruction term: log p(x|z) under N(x_hat, exp(self.log_scale)).
    recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)

    # KL term (single-sample Monte Carlo estimate).
    kl = self.kl_divergence(z, mu, std)

    # β-weighted negative ELBO, averaged over the batch.
    elbo = (β * kl - recon_loss).mean()
    return elbo

def kl_divergence(self, z, mu, std):
    """Single-sample Monte Carlo estimate of KL(q(z|x) || p(z)).

    The prior p(z) is a standard normal N(0, I); the posterior q(z|x) is
    N(mu, std). The estimate log q(z) - log p(z) is evaluated at the given
    sample z and summed over the latent dimensions, so the result has one
    value per batch element. Being a stochastic estimate, it can be negative
    for an individual sample even though the true KL is non-negative.
    """
    prior = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    posterior = torch.distributions.Normal(mu, std)

    # KL ≈ log q(z|x) - log p(z), evaluated at the single sample z.
    elementwise = posterior.log_prob(z) - prior.log_prob(z)
    return elementwise.sum(-1)
``````

I use a Sigmoid() activation at the decoder output. I train the model as follows:

``````for epoch in range(0, epochs + 1):
# Epoch 0 is skipped so the untrained network can be evaluated first.
if epoch > 0:  # test untrained net first
model.train()
train_loss = 0
# NOTE(review): the optimizer is re-created every epoch, which discards any
# accumulated optimizer state (e.g. Adam moment estimates) — confirm this is
# intended; usually the optimizer is built once, before the epoch loop.
optimizer = model.setOptimizer(model)
for x in loop:
x = x.to(device)
# Forward pass returns (reconstruction, mu, logvar, z).
x_hat, mu, logvar, features = model(x)
loss = model.loss_function(x_hat, x, mu, logvar)
# Accumulate the scalar loss for epoch-level reporting.
# NOTE(review): the snippet appears truncated — no optimizer.zero_grad(),
# loss.backward(), or optimizer.step() is shown here.
train_loss += loss.item()
``````

The loss doesn’t settle at zero but becomes negative (around -2). If I remove `train_loss /= len(train_loader.dataset)`, it diverges completely.