VAE with ResNet backbone does not train

I have created a VAE based on a ResNet backbone (note: in all tensors below, the first dimension is the batch).

As a loss function, I use:

    # KL divergence: summed over the latent dimensions, averaged over the batch
    kldLoss = (-0.5 * torch.sum(1 + logVar2 - mu.pow(2) - logVar2.exp(), dim=1)).mean()
    # Reconstruction term: MSE between the decoder output and the target
    decLoss = mseLoss(decFromInput, target)
    loss = kldLoss + decLoss

where decFromInput is the output of the decoder, and mu and logVar2 are the outputs of the encoder before the reparametrization trick. I have also tried BCE instead of MSE, but the outcome was the same.
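Note that the two terms are reduced differently: torch.sum(..., dim=1) makes kldLoss a per-sample sum over the 128 latent dimensions, while MSE with its default reduction averages over all 64x64 pixels, so the terms live on very different scales. A minimal sketch of one way to put both on the same per-sample footing (the reduction="sum" variant here is a suggestion, not my original code):

    import torch
    import torch.nn.functional as F

    def vaeLoss(decFromInput, target, mu, logVar2):
        # KL divergence: sum over latent dims, then mean over the batch
        kldLoss = (-0.5 * torch.sum(1 + logVar2 - mu.pow(2) - logVar2.exp(), dim=1)).mean()
        # Reconstruction: sum over pixels per sample, then mean over the batch,
        # so both terms are per-sample sums and comparable in magnitude
        decLoss = F.mse_loss(decFromInput, target, reduction="sum") / target.size(0)
        return kldLoss + decLoss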

The optimizer is:

`optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0002)`

To reparametrize, I use the standard trick:

    std = torch.exp(0.5 * logVar2)
    epsilon = self.normDist.sample(std.shape).to(std.device)
    z = mu + std * epsilon  # z ~ N(mu, std^2), differentiable w.r.t. mu and std

where self.normDist is torch.distributions.Normal(0, 1) and z is the final output of the encoder (i.e., the latent-space representation).
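An equivalent formulation samples epsilon with torch.randn_like, which lands on the right device and dtype without the .to(...) call (a sketch, functionally the same as the code above):

    import torch

    def reparametrize(mu, logVar2):
        std = torch.exp(0.5 * logVar2)
        epsilon = torch.randn_like(std)  # epsilon ~ N(0, 1), same shape/device as std
        return mu + std * epsilon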

When I run the network on MNIST (1x64x64), the decoded output after several iterations is incorrect, just a blob (the output has been basically unchanged since the first epoch). I use a training batch size of 64 and full precision (no AMP). The training loop is a standard backprop step after each batch, with no gradient accumulation.
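For completeness, a minimal sketch of that loop (names like dataLoader, device, numEpochs, and the model returning (decFromInput, mu, logVar2) are assumptions for illustration):

    mseLoss = torch.nn.MSELoss()
    for epoch in range(numEpochs):
        for batch, _ in dataLoader:  # MNIST labels are unused
            batch = batch.to(device)
            decFromInput, mu, logVar2 = model(batch)
            kldLoss = (-0.5 * torch.sum(1 + logVar2 - mu.pow(2) - logVar2.exp(), dim=1)).mean()
            decLoss = mseLoss(decFromInput, batch)
            loss = kldLoss + decLoss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()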

The loss value more or less oscillates and does not converge.

I have also sometimes obtained NaNs in the outputs and in the loss. However, I cannot reproduce this, since it happens randomly.
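A common cause of such random NaNs in VAEs is logVar2.exp() overflowing when the encoder output grows large. Two cheap safeguards worth trying (the clamp bounds below are illustrative assumptions):

    # Clamp the log-variance so exp() cannot overflow to Inf
    logVar2 = torch.clamp(logVar2, min=-10.0, max=10.0)

    # During debugging only (slow): make autograd raise at the first op producing NaN
    torch.autograd.set_detect_anomaly(True)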

The torchinfo summary of my encoder-decoder is:

    =========================================================================================================
    Layer (type (var_name):depth-idx)                       Input Shape               Output Shape
    =========================================================================================================
    VAEModel (VAEModel)                                     [1, 1, 64, 64]            [1, 1, 64, 64]
    ├─VAEEncoder (modelEnc): 1-1                            [1, 1, 64, 64]            [1, 128]
    │    └─Sequential (conv1): 2-1                          [1, 1, 64, 64]            [1, 128, 64, 64]
    │    │    └─Conv2d (0): 3-1                             [1, 1, 64, 64]            [1, 128, 64, 64]
    │    │    └─BatchNorm2d (1): 3-2                        [1, 128, 64, 64]          [1, 128, 64, 64]
    │    │    └─ReLU (2): 3-3                               [1, 128, 64, 64]          [1, 128, 64, 64]
    │    └─ModuleList (layers): 2-2                         --                        --
    │    │    └─Sequential (0): 3-4                         [1, 128, 64, 64]          [1, 128, 32, 32]
    │    │    │    └─ResNetBlock (0): 4-1                   [1, 128, 64, 64]          [1, 128, 64, 64]
    │    │    │    └─ResNetBlock (1): 4-2                   [1, 128, 64, 64]          [1, 128, 64, 64]
    │    │    │    └─Sequential (2): 4-3                    [1, 128, 64, 64]          [1, 128, 32, 32]
    │    │    └─Sequential (1): 3-5                         [1, 128, 32, 32]          [1, 256, 16, 16]
    │    │    │    └─ResNetBlock (0): 4-4                   [1, 128, 32, 32]          [1, 128, 32, 32]
    │    │    │    └─ResNetBlock (1): 4-5                   [1, 128, 32, 32]          [1, 256, 32, 32]
    │    │    │    └─Sequential (2): 4-6                    [1, 256, 32, 32]          [1, 256, 16, 16]
    │    │    └─Sequential (2): 3-6                         [1, 256, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─ResNetBlock (0): 4-7                   [1, 256, 16, 16]          [1, 256, 16, 16]
    │    │    │    └─ResNetBlock (1): 4-8                   [1, 256, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─Identity (2): 4-9                      [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    └─Sequential (3): 3-7                         [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─ResNetBlock (0): 4-10                  [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─ResNetBlock (1): 4-11                  [1, 512, 16, 16]          [1, 512, 16, 16]
    │    └─Sequential (toLatentImg): 2-3                    [1, 512, 16, 16]          [1, 6, 16, 16]
    │    │    └─BatchNorm2d (0): 3-8                        [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    └─ReLU (1): 3-9                               [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    └─Conv2d (2): 3-10                            [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    └─Conv2d (3): 3-11                            [1, 512, 16, 16]          [1, 6, 16, 16]
    │    └─Linear (fc_mu): 2-4                              [1, 1536]                 [1, 128]
    │    └─Linear (fc_logvar2): 2-5                         [1, 1536]                 [1, 128]
    ├─VAEDecoder (modelDec): 1-2                            [1, 128]                  [1, 1, 64, 64]
    │    └─Sequential (fromLatent): 2-6                     [1, 128]                  [1, 6, 16, 16]
    │    │    └─Linear (0): 3-12                            [1, 128]                  [1, 1536]
    │    │    └─View (1): 3-13                              [1, 1536]                 [1, 6, 16, 16]
    │    └─Sequential (fromLatentImg): 2-7                  [1, 6, 16, 16]            [1, 512, 16, 16]
    │    │    └─Conv2d (0): 3-14                            [1, 6, 16, 16]            [1, 512, 16, 16]
    │    │    └─Conv2d (1): 3-15                            [1, 512, 16, 16]          [1, 512, 16, 16]
    │    └─ModuleList (layers): 2-8                         --                        --
    │    │    └─Sequential (0): 3-16                        [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─ResNetBlock (0): 4-12                  [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─ResNetBlock (1): 4-13                  [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─Identity (2): 4-14                     [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    └─Sequential (1): 3-17                        [1, 512, 16, 16]          [1, 256, 32, 32]
    │    │    │    └─ResNetBlock (0): 4-15                  [1, 512, 16, 16]          [1, 512, 16, 16]
    │    │    │    └─ResNetBlock (1): 4-16                  [1, 512, 16, 16]          [1, 256, 16, 16]
    │    │    │    └─Sequential (2): 4-17                   [1, 256, 16, 16]          [1, 256, 32, 32]
    │    │    └─Sequential (2): 3-18                        [1, 256, 32, 32]          [1, 128, 64, 64]
    │    │    │    └─ResNetBlock (0): 4-18                  [1, 256, 32, 32]          [1, 256, 32, 32]
    │    │    │    └─ResNetBlock (1): 4-19                  [1, 256, 32, 32]          [1, 128, 32, 32]
    │    │    │    └─Sequential (2): 4-20                   [1, 128, 32, 32]          [1, 128, 64, 64]
    │    └─Sequential (conv1): 2-9                          [1, 128, 64, 64]          [1, 1, 64, 64]
    │    │    └─BatchNorm2d (0): 3-19                       [1, 128, 64, 64]          [1, 128, 64, 64]
    │    │    └─ReLU (1): 3-20                              [1, 128, 64, 64]          [1, 128, 64, 64]
    │    │    └─Conv2d (2): 3-21                            [1, 128, 64, 64]          [1, 1, 64, 64]
    =========================================================================================================

ResNetBlock is the usual ResNet block with ReLU, Conv2d and BatchNorm2d modules.
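Based on that description and the shapes in the summary (channels change inside some blocks, e.g. 128 -> 256 at constant spatial size), a pre-activation block roughly along these lines would match; the exact kernel sizes and ordering are assumptions, not my verbatim code:

    import torch.nn as nn

    class ResNetBlock(nn.Module):
        """Pre-activation residual block: (BN -> ReLU -> Conv2d) twice."""
        def __init__(self, inChannels, outChannels):
            super().__init__()
            self.body = nn.Sequential(
                nn.BatchNorm2d(inChannels),
                nn.ReLU(),
                nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
                nn.BatchNorm2d(outChannels),
                nn.ReLU(),
                nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1),
            )
            # 1x1 projection on the skip path when the channel count changes
            self.skip = (nn.Identity() if inChannels == outChannels
                         else nn.Conv2d(inChannels, outChannels, kernel_size=1))

        def forward(self, x):
            return self.skip(x) + self.body(x)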

[Image: VAE_2023_09_21_18_55_train_18_1 (rotated to save space)]

Solved: the line `loss = kldLoss + decLoss` was the problem. kldLoss is way too high compared to the MSE term (e.g., 30 vs. 0.02). I added a scale factor to kldLoss, and the network now converges correctly.
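In code, this is the usual beta-VAE-style weighting (the value below is illustrative, not my exact factor):

    beta = 0.001  # illustrative; tune so kldLoss and decLoss have similar magnitude
    loss = beta * kldLoss + decLoss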
