I have been trying to implement convolutional VAE in PyTorch for a while now and am somehow not able to correctly train my network. Here’s the encoder, decoder, and training loop. I am training the model on MNIST dataset.
Encoder output: Two tensors (loc, logvar) of shape [batch_size, latent_dims]
Decoder output: Image of shape [batch_size, 1, 28, 28]
Problem: Loss remains almost constant.
Encoder
class LocLogvar(nn.Module):
def __init__(self, in_features, latent_dims):
super(LocLogvar, self).__init__()
self.loc = nn.Linear(in_features, latent_dims)
self.logvar = nn.Linear(in_features, latent_dims)
def forward(self, inputs):
loc = self.loc(inputs)
logvar = self.logvar(inputs)
return loc, logvar
encoder = nn.Sequential(OrderedDict([
('e_conv_layer_1', nn.Conv2d(1, 16, 5, 1)), # 16 x 24 x 24
('e_relu_layer_1', IPLReLU()),
('e_batch_norm_1', nn.BatchNorm2d(16)),
('e_conv_layer_2', nn.Conv2d(16, 32, 5, 1)), # 32 x 20 x 20
('e_relu_layer_2', IPLReLU()),
('e_batch_norm_2', nn.BatchNorm2d(32)),
('e_conv_layer_3', nn.Conv2d(32, 32, 11, 1)), # 32 x 10 x 10
('e_relu_layer_3', IPLReLU()),
('e_batch_norm_3', nn.BatchNorm2d(32)),
('e_conv_layer_4', nn.Conv2d(32, 64, 5, 1)), # 64 x 6 x 6
('e_relu_layer_4', IPLReLU()),
('e_batch_norm_4', nn.BatchNorm2d(64)),
('e_dropout_layer_1', nn.Dropout2d(p=0.75)),
('e_conv_layer_5', nn.Conv2d(64, 128, 5, 1)), # 128 x 2 x 2
('e_relu_layer_5', IPLReLU()),
('e_batch_norm_5', nn.BatchNorm2d(128)),
('e_dropout_layer_1', nn.Dropout2d(p=0.85)),
('e_flatten_layer', nn.Flatten()),
('out_layer', LocLogvar(128*2*2, latent_dims))
]))
Decoder
class Reshape(nn.Module):
def __init__(self, *shape):
super(Reshape, self).__init__()
self.shape = shape
def forward(self, X):
return X.view(-1, *self.shape)
decoder = nn.Sequential(OrderedDict([
('inv_linear_layer_1', nn.Linear(latent_dims, 128*2*2)), # 128 * 2 * 2
('inv_relu_layer_5', IPLReLU()),
('inv_flatten_layer', Reshape(128, 2, 2)), # 128 x 2 x 2
('inv_conv_layer_5', nn.ConvTranspose2d(128, 64, 5, 1)), # 64 x 6 x 6
('inv_batch_norm_5', nn.BatchNorm2d(64)),
('inv_relu_layer_4', IPLReLU()),
('inv_conv_layer_4', nn.ConvTranspose2d(64, 32, 5, 1)), # 32 x 10 x 10
('inv_batch_norm_4', nn.BatchNorm2d(32)),
('inv_relu_layer_3', IPLReLU()),
('inv_conv_layer_3', nn.ConvTranspose2d(32, 32, 11, 1)), # 32 x 20 x 20
('inv_batch_norm_3', nn.BatchNorm2d(32)),
('inv_relu_layer_2', IPLReLU()),
('inv_conv_layer_2', nn.ConvTranspose2d(32, 16, 5, 1)), # 16 x 24 x 24
('inv_batch_norm_2', nn.BatchNorm2d(16)),
('inv_relu_layer_1', IPLReLU()),
('inv_conv_layer_1', nn.ConvTranspose2d(16, 1, 5, 1)), # 1 x 28 x 28
('inv_batch_norm_1', nn.BatchNorm2d(1)),
('out_layer', nn.Sigmoid())
]))
Loss function
def loss_fn(loc, logvar, reconstructed, img):
reconstructed = reconstructed.view(-1, 784)
img = img.view(-1, 784)
recon_loss = -torch.sum(
img*torch.log(reconstructed+1e-18) + (1-img)*torch.log(1-reconstructed+1e-18)
)
kl_loss = 0.5 * torch.sum(
-logvar - 1 + logvar.exp() + loc**2
)
return recon_loss + kl_loss
Training Loop
encoder = encoder.to(torch.device('cuda'))
decoder = decoder.to(torch.device('cuda'))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for i in range(epochs):
for idx in range(x_train.shape[0] // batch_size):
# Get the batch to train
x_batch = x_train[idx*batch_size:(idx+1)*batch_size, ...]
# Forward pass through encoder
loc, logvar = encoder(x_batch)
# Reparameterize
epsilon = torch.randn_like(loc)
z = loc + torch.exp(logvar * 0.5) * epsilon
# Forward pass through decoder
reconstructed = decoder(z)
loss = loss_fn(loc, logvar, reconstructed, x_batch)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Print the results
stdout.write(f'\r epoch : {i}\t'
f'step : {min((idx+1)*batch_size, x_train.shape[0])}/{x_train.shape[0]}\t'
f'loss : {loss.item():.3f}\t')
print()
Output
epoch : 0 step : 60000/60000 loss : 634711.188
epoch : 1 step : 60000/60000 loss : 635014.250
epoch : 2 step : 60000/60000 loss : 634935.625
epoch : 3 step : 60000/60000 loss : 635042.812
epoch : 4 step : 60000/60000 loss : 634130.562
epoch : 5 step : 60000/60000 loss : 634780.438
epoch : 6 step : 60000/60000 loss : 634427.250
epoch : 7 step : 60000/60000 loss : 634962.000
epoch : 8 step : 60000/60000 loss : 634118.750
epoch : 9 step : 60000/60000 loss : 636151.125
I suspect some gradients are getting detached somewhere but I am not sure what is causing this behaviour. It would be great if someone can suggest my mistake in the model