I have created a VAE based on a ResNet backbone (note: in all tensors below, the first dimension is the batch dimension).
As a loss function, I use:
kldLoss = (-0.5 * torch.sum(1 + logVar2 - mu.pow(2) - logVar2.exp(), dim=1)).mean()
decLoss = mseLoss(decFromInput, target)
loss = kldLoss + decLoss
where `decFromInput` is the decoded output from the decoder, and `mu` and `logVar2` are the outputs of the encoder before the reparameterization trick. I have also tried BCE instead of MSE as the reconstruction loss, but the outcome was the same.
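For completeness, here is the loss put together as one self-contained helper (a minimal sketch; the `useBce` flag only switches the reconstruction term between the two variants I tried, and the BCE branch assumes the decoder output and target are already in [0, 1]):

```python
import torch
import torch.nn as nn

mseLoss = nn.MSELoss()
bceLoss = nn.BCELoss()  # expects inputs and targets in [0, 1]

def vae_loss(decFromInput, target, mu, logVar2, useBce=False):
    # KL divergence between N(mu, sigma^2) and N(0, 1),
    # summed over the latent dimension and averaged over the batch
    kldLoss = (-0.5 * torch.sum(1 + logVar2 - mu.pow(2) - logVar2.exp(), dim=1)).mean()
    # Reconstruction term, averaged over every pixel of every sample
    recon = bceLoss if useBce else mseLoss
    decLoss = recon(decFromInput, target)
    return kldLoss + decLoss
```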
The optimizer is:
`optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0002)`
To reparameterize, I use the standard trick:
std = torch.exp(0.5 * logVar2)
epsilon = self.normDist.sample(std.shape).to(std.device)
z = mu + std * epsilon
where `self.normDist` is `torch.distributions.Normal(0, 1)` and `z` is the final output of the encoder (i.e. the latent-space representation).
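As a self-contained sketch (the `[B, 128]` shapes for `mu` and `logVar2` follow the summary further below; everything else is exactly the three lines above):

```python
import torch

normDist = torch.distributions.Normal(0, 1)

def reparameterize(mu, logVar2):
    std = torch.exp(0.5 * logVar2)                        # sigma = exp(logVar2 / 2)
    epsilon = normDist.sample(std.shape).to(std.device)   # eps ~ N(0, 1), same shape as std
    return mu + std * epsilon                             # z = mu + sigma * eps

# Shape check: with mu = logVar2 = 0, z should be standard normal
mu = torch.zeros(64, 128)
logVar2 = torch.zeros(64, 128)
z = reparameterize(mu, logVar2)                           # -> torch.Size([64, 128])
```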
When I run the network on MNIST (1x64x64 inputs), the decoded output after several iterations is just a blob, and it has looked basically the same since the first epoch. I use a training batch size of 64 and full precision (no AMP). The training loop is standard backprop after each batch, with no gradient accumulation.
The loss value more or less oscillates and does not converge.
I have also occasionally obtained NaNs in the outputs and in the loss. However, I cannot reproduce this, since it happens randomly.
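For reference, the training step is essentially the following (a sketch under my assumptions: `trainLoader` yields 1x64x64 MNIST batches, `self.model` returns the reconstruction together with `mu` and `logVar2`, and `vae_loss` is the helper sketched above; the `isnan` check is only there to show how I notice the occasional NaNs):

```python
for epoch in range(numEpochs):
    for batch, _ in trainLoader:
        batch = batch.to(device)

        decFromInput, mu, logVar2 = self.model(batch)    # encode -> reparameterize -> decode
        loss = vae_loss(decFromInput, batch, mu, logVar2)

        optimizer.zero_grad()
        loss.backward()                                  # backprop after every batch
        optimizer.step()                                 # no gradient accumulation, no AMP

        if torch.isnan(loss):                            # happens only occasionally
            print(f"NaN loss in epoch {epoch}")
```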
The torchinfo summary of my encoder and decoder is:
=======================================================================================================
Layer (type (var_name):depth-idx) Input Shape Output Shape
=========================================================================================================
VAEModel (VAEModel) [1, 1, 64, 64] [1, 1, 64, 64]
├─VAEEncoder (modelEnc): 1-1 [1, 1, 64, 64] [1, 128]
│ └─Sequential (conv1): 2-1 [1, 1, 64, 64] [1, 128, 64, 64]
│ │ └─Conv2d (0): 3-1 [1, 1, 64, 64] [1, 128, 64, 64]
│ │ └─BatchNorm2d (1): 3-2 [1, 128, 64, 64] [1, 128, 64, 64]
│ │ └─ReLU (2): 3-3 [1, 128, 64, 64] [1, 128, 64, 64]
│ └─ModuleList (layers): 2-2 -- --
│ │ └─Sequential (0): 3-4 [1, 128, 64, 64] [1, 128, 32, 32]
│ │ │ └─ResNetBlock (0): 4-1 [1, 128, 64, 64] [1, 128, 64, 64]
│ │ │ └─ResNetBlock (1): 4-2 [1, 128, 64, 64] [1, 128, 64, 64]
│ │ │ └─Sequential (2): 4-3 [1, 128, 64, 64] [1, 128, 32, 32]
│ │ └─Sequential (1): 3-5 [1, 128, 32, 32] [1, 256, 16, 16]
│ │ │ └─ResNetBlock (0): 4-4 [1, 128, 32, 32] [1, 128, 32, 32]
│ │ │ └─ResNetBlock (1): 4-5 [1, 128, 32, 32] [1, 256, 32, 32]
│ │ │ └─Sequential (2): 4-6 [1, 256, 32, 32] [1, 256, 16, 16]
│ │ └─Sequential (2): 3-6 [1, 256, 16, 16] [1, 512, 16, 16]
│ │ │ └─ResNetBlock (0): 4-7 [1, 256, 16, 16] [1, 256, 16, 16]
│ │ │ └─ResNetBlock (1): 4-8 [1, 256, 16, 16] [1, 512, 16, 16]
│ │ │ └─Identity (2): 4-9 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ └─Sequential (3): 3-7 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ │ └─ResNetBlock (0): 4-10 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ │ └─ResNetBlock (1): 4-11 [1, 512, 16, 16] [1, 512, 16, 16]
│ └─Sequential (toLatentImg): 2-3 [1, 512, 16, 16] [1, 6, 16, 16]
│ │ └─BatchNorm2d (0): 3-8 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ └─ReLU (1): 3-9 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ └─Conv2d (2): 3-10 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ └─Conv2d (3): 3-11 [1, 512, 16, 16] [1, 6, 16, 16]
│ └─Linear (fc_mu): 2-4 [1, 1536] [1, 128]
│ └─Linear (fc_logvar2): 2-5 [1, 1536] [1, 128]
├─VAEDecoder (modelDec): 1-2 [1, 128] [1, 1, 64, 64]
│ └─Sequential (fromLatent): 2-6 [1, 128] [1, 6, 16, 16]
│ │ └─Linear (0): 3-12 [1, 128] [1, 1536]
│ │ └─View (1): 3-13 [1, 1536] [1, 6, 16, 16]
│ └─Sequential (fromLatentImg): 2-7 [1, 6, 16, 16] [1, 512, 16, 16]
│ │ └─Conv2d (0): 3-14 [1, 6, 16, 16] [1, 512, 16, 16]
│ │ └─Conv2d (1): 3-15 [1, 512, 16, 16] [1, 512, 16, 16]
│ └─ModuleList (layers): 2-8 -- --
│ │ └─Sequential (0): 3-16 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ │ └─ResNetBlock (0): 4-12 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ │ └─ResNetBlock (1): 4-13 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ │ └─Identity (2): 4-14 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ └─Sequential (1): 3-17 [1, 512, 16, 16] [1, 256, 32, 32]
│ │ │ └─ResNetBlock (0): 4-15 [1, 512, 16, 16] [1, 512, 16, 16]
│ │ │ └─ResNetBlock (1): 4-16 [1, 512, 16, 16] [1, 256, 16, 16]
│ │ │ └─Sequential (2): 4-17 [1, 256, 16, 16] [1, 256, 32, 32]
│ │ └─Sequential (2): 3-18 [1, 256, 32, 32] [1, 128, 64, 64]
│ │ │ └─ResNetBlock (0): 4-18 [1, 256, 32, 32] [1, 256, 32, 32]
│ │ │ └─ResNetBlock (1): 4-19 [1, 256, 32, 32] [1, 128, 32, 32]
│ │ │ └─Sequential (2): 4-20 [1, 128, 32, 32] [1, 128, 64, 64]
│ └─Sequential (conv1): 2-9 [1, 128, 64, 64] [1, 1, 64, 64]
│ │ └─BatchNorm2d (0): 3-19 [1, 128, 64, 64] [1, 128, 64, 64]
│ │ └─ReLU (1): 3-20 [1, 128, 64, 64] [1, 128, 64, 64]
│ │ └─Conv2d (2): 3-21 [1, 128, 64, 64] [1, 1, 64, 64]
=========================================================================================================
ResNetBlock is the usual ResNet block with ReLU, Conv2d and BatchNorm2d modules.
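For context, by ResNetBlock I mean essentially the following (a minimal sketch; the exact kernel sizes, the pre-activation ordering, and the 1x1 projection on the skip path are assumptions and may differ slightly from my real block):

```python
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> Conv2d) twice, plus a skip connection."""
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(inChannels),
            nn.ReLU(),
            nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
            nn.BatchNorm2d(outChannels),
            nn.ReLU(),
            nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1),
        )
        # 1x1 projection on the skip path when the channel count changes
        self.skip = (nn.Identity() if inChannels == outChannels
                     else nn.Conv2d(inChannels, outChannels, kernel_size=1))

    def forward(self, x):
        return self.block(x) + self.skip(x)
```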