Vanishing Gradient in Autoencoders

Hello!

I am trying to train a convolutional autoencoder to classify signals based on their shapes. My dataset consists of sinusoidal signals with two different frequencies and randomized phases, and my goal is to classify the signals by their period. The autoencoder is based on code from someone who worked on a similar project.
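
For context, the data looks roughly like this (a minimal sketch; the exact frequencies, signal length, and sample counts below are placeholders, not my real values):

import numpy as np

n_samples, input_size = 1000, 256               # placeholder sizes
freqs = [2.0, 5.0]                              # two placeholder frequencies
t = np.linspace(0.0, 1.0, input_size)

labels = np.random.randint(0, 2, n_samples)     # which of the two frequencies each signal gets
phases = np.random.uniform(0.0, 2.0 * np.pi, n_samples)
# each signal is a sine at one of the two frequencies with a random phase
signals = np.array([np.sin(2.0 * np.pi * freqs[l] * t + p)
                    for l, p in zip(labels, phases)])

The network model is the following: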

def warn(*args, **kwargs):
    pass

import warnings
warnings.warn = warn  # silence all library warnings

import pandas as pd, numpy as np
import os, matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, TensorDataset

import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self, input_size, embedding):
        super(Autoencoder, self).__init__()
        self.input_size = input_size
        self.embedding = embedding

        """
        Encoder Layers
        """
        self.conv1_enc = nn.Conv1d(in_channels = 1, out_channels = 1, kernel_size = 9, stride = 1, padding = 'same')
        self.conv2_enc = nn.Conv1d(in_channels = 1, out_channels = 1, kernel_size = 9, stride = 1, padding = 'same')
        self.conv3_enc = nn.Conv1d(in_channels = 1, out_channels = 1, kernel_size = 5, stride = 1, padding = 'same')
        self.conv4_enc = nn.Conv1d(in_channels = 1, out_channels = 1, kernel_size = 3, stride = 1, padding = 'same')
        self.conv5_enc = nn.Conv1d(in_channels = 1, out_channels = 1, kernel_size = 3, stride = 1, padding = 'same')

        self.pool1_enc = nn.MaxPool1d(kernel_size = 2, stride = 2, padding = 0)
        self.pool2_enc = nn.MaxPool1d(kernel_size = 2, stride = 2, padding = 0)
        self.pool3_enc = nn.MaxPool1d(kernel_size = 2, stride = 2, padding = 0)
        self.pool4_enc = nn.MaxPool1d(kernel_size = 2, stride = 2, padding = 0)

        #self.dense1_enc = nn.Linear(int(self.input_size / 16), embedding)
        self.dense1_enc = nn.Linear(int(self.input_size / 16), int(self.input_size / 16))
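        # note: the embedding-sized bottleneck (commented out above) is disabled,
        # so this layer keeps the full flattened size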

        self.norm1_enc = nn.BatchNorm1d(1)
        self.norm2_enc = nn.BatchNorm1d(1)
        self.norm3_enc = nn.BatchNorm1d(1)
        self.norm4_enc = nn.BatchNorm1d(1)
        self.norm5_enc = nn.BatchNorm1d(1)
        self.norm6_enc = nn.BatchNorm1d(16)
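        # applied to the flattened 2-D vector, BatchNorm1d(16) assumes
        # input_size / 16 == 16, i.e. input_size == 256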

        self.relu = nn.ReLU()

        """
        Decoder Layers
        """

        self.conv1_dec = nn.ConvTranspose1d(in_channels = 1, out_channels = 1, kernel_size = 3, stride = 1, padding = 1)
        self.conv2_dec = nn.ConvTranspose1d(in_channels = 1, out_channels = 1, kernel_size = 3, stride = 1, padding = 1)
        self.conv3_dec = nn.ConvTranspose1d(in_channels = 1, out_channels = 1, kernel_size = 5, stride = 1, padding = 2)
        self.conv4_dec = nn.ConvTranspose1d(in_channels = 1, out_channels = 1, kernel_size = 9, stride = 1, padding = 4)
        self.conv5_dec = nn.ConvTranspose1d(in_channels = 1, out_channels = 1, kernel_size = 9, stride = 1, padding = 4)

        self.up1_dec = nn.Upsample(size = int(self.input_size / 8), mode = "nearest")
        self.up2_dec = nn.Upsample(size = int(self.input_size / 4), mode = "nearest")
        self.up3_dec = nn.Upsample(size = int(self.input_size / 2), mode = "nearest")
        self.up4_dec = nn.Upsample(size = int(self.input_size), mode = "nearest")

        #self.dense1_dec = nn.Linear(embedding, int(self.input_size / 16))
        self.dense1_dec = nn.Linear(int(self.input_size / 16), int(self.input_size / 16))


        self.norm1_dec = nn.BatchNorm1d(16)
        self.norm2_dec = nn.BatchNorm1d(1)
        self.norm3_dec = nn.BatchNorm1d(1)
        self.norm4_dec = nn.BatchNorm1d(1)
        self.norm5_dec = nn.BatchNorm1d(1)
        self.norm6_dec = nn.BatchNorm1d(1)

        self.sigmoid = nn.Sigmoid()
        #self.softmax = nn.Softmax(dim = 1)

    def forward(self, activation):
        #print('shape0: ', activation.size())
        encoded = self.conv1_enc(activation)
        encoded = self.norm1_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.conv2_enc(encoded)
        encoded = self.norm2_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool1_enc(encoded)

        encoded = self.conv3_enc(encoded)
        encoded = self.norm3_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool2_enc(encoded)

        encoded = self.conv4_enc(encoded)
        encoded = self.norm4_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool3_enc(encoded)

        encoded = self.conv5_enc(encoded)
        encoded = self.norm5_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool4_enc(encoded)

        self.orig_shape = encoded.size()
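        # ^ saved so that unflatten() can restore this shape in the decoder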

        encoded = self.flatten(encoded)
        embedding = self.dense1_enc(encoded)
        embedding = self.norm6_enc(embedding)
        embedding = self.relu(embedding)

        decoded = self.dense1_dec(embedding)
        decoded = self.norm1_dec(decoded)
        decoded = self.relu(decoded)
        decoded = self.unflatten(decoded)

        decoded = self.up1_dec(decoded)

        decoded = self.conv1_dec(decoded)
        decoded = self.norm2_dec(decoded)
        decoded = self.relu(decoded)

        decoded = self.conv2_dec(decoded)
        decoded = self.norm3_dec(decoded)
        decoded = self.relu(decoded)

        decoded = self.up2_dec(decoded)

        decoded = self.conv3_dec(decoded)
        decoded = self.norm4_dec(decoded)
        decoded = self.relu(decoded)

        decoded = self.up3_dec(decoded)

        decoded = self.conv4_dec(decoded)
        decoded = self.norm5_dec(decoded)
        decoded = self.relu(decoded)

        decoded = self.up4_dec(decoded)

        decoded = self.conv5_dec(decoded)
        decoded = self.norm6_dec(decoded)
        decoded = self.relu(decoded)

        decoded = self.sigmoid(decoded)
        #decoded = self.softmax(decoded)

        return decoded

    def get_embedding(self, activation):
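        # same as the encoder half of forward(); returns the latent vector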
        encoded = self.conv1_enc(activation)
        encoded = self.norm1_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.conv2_enc(encoded)
        encoded = self.norm2_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool1_enc(encoded)

        encoded = self.conv3_enc(encoded)
        encoded = self.norm3_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool2_enc(encoded)

        encoded = self.conv4_enc(encoded)
        encoded = self.norm4_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool3_enc(encoded)

        encoded = self.conv5_enc(encoded)
        encoded = self.norm5_enc(encoded)
        encoded = self.relu(encoded)

        encoded = self.pool4_enc(encoded)

        encoded = self.flatten(encoded)
        embedding = self.dense1_enc(encoded)
        embedding = self.norm6_enc(embedding)
        embedding = self.relu(embedding)

        return embedding


    def flatten(self, x):
        x = x.view(x.size(0), -1)
        return x


    def unflatten(self, x):
        x = x.view(self.orig_shape)
        return x
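
I train it roughly like this (a sketch; the optimizer, learning rate, batch size, and epoch count below are illustrative, and `signals` comes from the data sketch above):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

x = torch.tensor(signals, dtype = torch.float32).unsqueeze(1)   # (N, 1, input_size)
loader = DataLoader(TensorDataset(x), batch_size = 32, shuffle = True)

model = Autoencoder(input_size = 256, embedding = 8)            # placeholder arguments
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
criterion = nn.MSELoss()                                        # reconstruction loss

for epoch in range(50):
    for (batch,) in loader:
        optimizer.zero_grad()
        reconstruction = model(batch)
        loss = criterion(reconstruction, batch)
        loss.backward()
        optimizer.step()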

When I train this network, the loss goes down very quickly at the beginning and then stops changing. I also used wandb to plot the gradients and parameters, and it turned out that the gradients of the weights and biases for quite a few convolutional and BatchNorm layers are zero.
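
This is roughly how I watch them (a sketch; "my-project" is a placeholder name, and the manual loop is an equivalent way to print per-layer gradient norms outside of wandb):

import wandb

wandb.init(project = "my-project")
wandb.watch(model, log = "all", log_freq = 100)

# equivalent manual check, run after loss.backward():
for name, p in model.named_parameters():
    if p.grad is not None:
        print(name, p.grad.norm().item())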

Could anyone give me some suggestions about why this might happen? A code-specific answer would be great, but general suggestions or links to resources I can read to help debug this are also appreciated. Please let me know if you need more information. Thank you in advance!