Hey, I'm trying to implement a VAE, and the model is:

```
import numpy as np
import torch
import torch.nn as nn


class VAE(nn.Module):
    def __init__(self, z_dims=60):
        super(VAE, self).__init__()
        # q_zx and p_xz are my encoder/decoder networks, defined elsewhere;
        # batch_size is a global in my script.
        self.encoder = q_zx(z_dims=z_dims)
        self.decoder = p_xz(z_dims=z_dims, batch_size=batch_size)

    def encode(self, x):
        mu_z, logvar_z = self.encoder(x)
        # mu_z     = (1, z_dims)
        # logvar_z = (1, z_dims)
        return mu_z, logvar_z

    def decode(self, x):
        mu_x, logvar_x = self.decoder(x)
        # mu_x     = (batch, 1, 64, 64)
        # logvar_x = (batch, 1, 64, 64)
        return mu_x, logvar_x

    def sample(self, mu_z, logvar_z):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.exp(0.5 * logvar_z)
        eps = torch.randn_like(std)
        return mu_z + eps * std

    def loss_function(self, mu_x, logvar_x, x, mu_z, logvar_z):
        # Reconstruction term: Gaussian log-likelihood of x under
        # N(mu_x, exp(logvar_x)), summed over pixels, averaged over the batch.
        mu_x = torch.flatten(mu_x, start_dim=1)
        logvar_x = torch.flatten(logvar_x, start_dim=1)
        loss_rec = torch.mean(-torch.sum(
            (-0.5 * np.log(2.0 * np.pi))
            + (-0.5 * logvar_x)
            + ((-0.5 / torch.exp(logvar_x)) * (x.view(-1, 64 * 64) - mu_x) ** 2.0),
            dim=1,
        ))
        # Closed-form KL divergence between N(mu_z, exp(logvar_z)) and N(0, I)
        KLD = torch.mean(-0.5 * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()))
        return loss_rec + KLD

    def forward(self, x):
        mu_z, logvar_z = self.encode(x)
        z_sampled = self.sample(mu_z, logvar_z)
        mu_x, logvar_x = self.decode(z_sampled)
        return mu_z, logvar_z, mu_x, logvar_x
```
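For what it's worth, my understanding is that the reconstruction term above should be equivalent to PyTorch's built-in `torch.nn.functional.gaussian_nll_loss` with `full=True`, which additionally clamps the variance at a small `eps` before dividing. Here is a sketch of what I mean (`recon_nll` is just my own name for it):

```
import torch
import torch.nn.functional as F

def recon_nll(mu_x, logvar_x, x):
    # Per-pixel Gaussian NLL, summed over pixels, averaged over the batch.
    # gaussian_nll_loss clamps var at eps (default 1e-6) before the division,
    # so a very negative logvar_x cannot blow up to inf/NaN here.
    mu_x = torch.flatten(mu_x, start_dim=1)
    var_x = torch.flatten(logvar_x, start_dim=1).exp()
    x = x.view(x.size(0), -1)
    nll = F.gaussian_nll_loss(mu_x, x, var_x, full=True, reduction='none')
    return nll.sum(dim=1).mean()
```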
The dimensions of mu_z, logvar_z, mu_x, and logvar_x are noted as comments in the encode and decode functions.
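To double-check those shapes at runtime (and to see which tensor first goes non-finite), I print a quick summary with a helper like this (`check` is just a hypothetical debugging function of mine):

```
def check(name, t):
    # Debug helper: report shape, finiteness, and value range of a tensor.
    print(f"{name}: shape={tuple(t.shape)}, "
          f"all_finite={torch.isfinite(t).all().item()}, "
          f"min={t.min().item():.3g}, max={t.max().item():.3g}")
```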
The train function I use is:

```
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        mu_z, logvar_z, mu_x, logvar_x = model(data)
        loss = model.loss_function(mu_x, logvar_x, data, mu_z, logvar_z)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 1200 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
```
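I also know about `torch.autograd.set_detect_anomaly`, which should make `backward()` raise an error at the first operation that produces a NaN; I plan to run with it enabled:

```
# Enable before the training loop; backward() will then error out at the op
# whose output produced the NaN, at the cost of slower execution.
torch.autograd.set_detect_anomaly(True)
```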
However, my loss is NaN. Why is that happening?

Thanks!