VAE training loss

Hey, im trying to implement a vae and the model is

class VAE(nn.Module):
    def __init__(self, z_dims=60):
        super(VAE, self).__init__()
        self.encoder = q_zx(z_dims=z_dims)
        self.decoder = p_xz(z_dims=z_dims, batch_size=batch_size)
        
    def encode(self, x):
        mu_z, logvar_z = self.encoder(x)
        # mu_z = (1, z_dims)
        # logvar_z = (1, z_dims)
        return mu_z, logvar_z
    
    def decode(self, x):
        mu_x, logvar_x = self.decoder(x)
        # mu_x = (batch, 1, 64, 64)
        # logvar_x = (batch, 1, 64, 64)
        return mu_x, logvar_x
    
    def sample(self, mu_z, logvar_z):
        std = torch.exp(0.5*logvar_z)
        eps = torch.randn_like(std)
        return mu_z + eps*std
    
    def loss_function(self, mu_x, logvar_x, x, mu_z, logvar_z):
        mu_x = torch.flatten(mu_x, start_dim = 1)
        logvar_x = torch.flatten(logvar_x, start_dim=1)
        loss_rec = torch.mean(-torch.sum(
        (-0.5 * np.log(2.0 * np.pi))
        + (-0.5 * logvar_x)
        + ((-0.5 / torch.exp(logvar_x)) * (x.view(-1, 64*64) - mu_x) ** 2.0),
        dim=1,
        ))
        KLD = torch.mean(-0.5 * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()))
        return loss_rec + KLD
    
    def forward(self, x):
        mu_z, logvar_z = self.encode(x)
        z_sampled = self.sample(mu_z, logvar_z)
        mu_x, logvar_x = self.decode(z_sampled)
        return mu_z, logvar_z, mu_x, logvar_x

The dimensions of mu_z, logvar_z, mu_x, logvar_x is written in the encode and decode func
The train function i use is

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        mu_z, logvar_z, mu_x, logvar_x = model(data)
        loss = model.loss_function(mu_x, logvar_x, data, mu_z, logvar_z)
        loss.backward()
        train_loss += loss.mean().item()
        optimizer.step()
        if batch_idx % 1200 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader),
                       loss.item() / len(data)))```

However my loss in nan
why is that happening??

Thanks

It’s because arg for logvar becomes big negative value and exp(logvar) goes to near 0.
The common trick is clip logvar above some value (e.g. -5) or exp(logvar) above small positive value (e.g. 1e-7).

Ah thanks, can you provide code on how to do that :sweat_smile:
thanks

In your loss_function(…), like …

var_x = torch.clip(torch.exp(logvar_x), min=1e-5)
loss_rec = torch.mean(-torch.sum(
        (-0.5 * np.log(2.0 * np.pi))
        + (-0.5 * logvar_x)
        + ((-0.5 / var_x) * (x.view(-1, 64*64) - mu_x) ** 2.0),
        dim=1,
        ))

By the way, if a vae model learn to std in decoder, it tend to go 0 and become deterministic (so, usually fixed to some small value, like 0.1). See Taming VAE paper by Rezende and Viola for more discussion about this ([1810.00597] Taming VAEs).

2 Likes