Bad reconstruction with LSTM-VAE

Hello,

I’m trying to implement an LSTM-VAE for anomaly detection. I have a problem: the model only learns a flat curve (roughly the mean) instead of the complex input signals.
My data are multivariate time series (3 channels) of variable duration.
To keep the dataloader simple, for the moment I don’t use batching (pack, pad, mask), so my training loop ingests one cycle at a time.
I use data scaling (robust or standard), a unidirectional multi-layer LSTM with layer normalization, and cyclical KL annealing to balance the loss between the MSE (LSTM reconstruction error) and the KLD (learnt latent Gaussian).
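
For reference, the loss_function called in the training loop below combines the MSE and the KLD with a cyclical β weight. A minimal sketch of what it does (simplified; the real implementation takes its settings from my config):

import torch
import torch.nn.functional as F

def kl_anneal_weight(anneal_function, step, total_steps, cycles, ratio):
    # Cyclical annealing: β ramps linearly from 0 to 1 over the first
    # `ratio` fraction of each cycle, then stays at 1 until the cycle ends
    if anneal_function == "cyclical":
        period = total_steps / cycles
        tau = (step % period) / period
        return min(tau / ratio, 1.0)
    return 1.0  # constant weighting otherwise

def loss_function(recon_x, x, mean, logvar, anneal_function, step, total_steps, cycles, ratio):
    mse = F.mse_loss(recon_x, x, reduction="mean")
    # KL divergence between N(mean, sigma^2) and the standard normal prior
    kld = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
    beta = kl_anneal_weight(anneal_function, step, total_steps, cycles, ratio)
    return mse + beta * kld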

Below is a reconstructed signal:

The first points of the reconstructed signal follow the input, but the reconstruction quickly collapses to a flat curve.
I tried an attention layer, different hidden-layer sizes, Conv1d, a bidirectional LSTM, … with no improvement.

Is it a problem with the network architecture or with my PyTorch implementation?

My model:

import torch
import torch.nn as nn


class LSTMVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, num_layers=1, dropout=0.3, device="cpu"):
        super().__init__()
        self.device = torch.device(device)

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.encoder_ln = nn.LayerNorm(hidden_dim)  # note: not applied in encode()
        self.hidden_to_mean = nn.Linear(hidden_dim, latent_dim)
        self.hidden_to_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)
        self.decoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder_ln = nn.LayerNorm(hidden_dim)  # note: not applied in decode()
        self.hidden_to_output = nn.Linear(hidden_dim, input_dim)

        # Residual layer (currently disabled)
        # self.residual_layer = nn.Sequential(nn.Linear(input_dim, input_dim), nn.Dropout(dropout))

        # Weight initialization, then move everything to the target device
        self._initialize_weights()
        self.to(self.device)

    def encode(self, x):
        x = self.dropout(x)
        _, (h, _) = self.encoder_lstm(x)  # h : (num_layers, batch_size, hidden_dim)
        h = h[-1]  # last layer : (batch_size, hidden_dim)
        mean = self.hidden_to_mean(h)
        logvar = self.hidden_to_logvar(h)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z, seq_len):
        # Map z to the LSTM initial hidden state, replicated for each layer
        h = self.latent_to_hidden(z).unsqueeze(0).repeat(self.num_layers, 1, 1)  # (num_layers, batch_size, hidden_dim)
        c = torch.zeros_like(h)  # initial cell state
        # The decoder is fed an all-zero input sequence: the latent code only
        # enters through the initial hidden state
        input_seq = torch.zeros((z.size(0), seq_len, self.input_dim), device=z.device)
        output, _ = self.decoder_lstm(input_seq, (h, c))
        recon_x = self.hidden_to_output(output)
        return recon_x

    def forward(self, x, seq_len):
        x = x.to(self.device)
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        recon_x = self.decode(z, seq_len)
        # Add residual connections
        # recon_x = recon_x + self.residual_layer(x)
        return recon_x, mean, logvar

    def _initialize_weights(self):
        for name, p in self.named_parameters():
            if 'lstm' in name:
                if 'weight_ih' in name:
                    nn.init.xavier_uniform_(p.data)
                elif 'weight_hh' in name:
                    nn.init.orthogonal_(p.data)
                elif 'bias_ih' in name:
                    p.data.fill_(0)
                    # Set forget-gate bias to 1 (PyTorch gate order: i, f, g, o)
                    n = p.size(0)
                    p.data[(n // 4):(n // 2)].fill_(1)
                elif 'bias_hh' in name:
                    p.data.fill_(0)
            elif 'hidden_to' in name or 'latent_to' in name:  # linear projection layers
                if 'weight' in name:
                    nn.init.xavier_uniform_(p.data)
                elif 'bias' in name:
                    p.data.fill_(0)
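
For reference, I build the model with the dimensions visible in the summary further down (3 input channels, hidden size 64, latent size 10, two layers). A quick smoke test on a short dummy cycle (real cycles are much longer):

model = LSTMVAE(input_dim=3, hidden_dim=64, latent_dim=10, num_layers=2, device="cpu")
x = torch.randn(1, 500, 3)  # one cycle: (batch=1, seq_len=500, channels=3)
recon, mean, logvar = model(x, seq_len=500)
print(recon.shape)  # torch.Size([1, 500, 3])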

And my train loop:

model.train()
step = 0
for epoch in range(cfg.lstm.epoch):
    for cycle, id_cycle, activity, _, seq_len in train_loader:
        logger.debug(f'Cycles: {id_cycle.item()}, activities: {activity}, seq_lens: {seq_len.item()}')
        # Push the cycle tensor and the seq_len variable to the GPU if present
        cycle, seq_len = cycle.to(device), seq_len.to(device)
        optimizer.zero_grad()
        recon_cycle, mean, logvar = model(cycle, seq_len)
        recon_error = nn.MSELoss()(recon_cycle, cycle).item()  # logged for monitoring only
        recon_errors_list.append(recon_error)
        loss = loss_function(recon_cycle, cycle, mean, logvar, anneal_function=cfg.lstm.kl_anneal_function,
                             step=step, total_steps=total_steps, cycles=cfg.lstm.kl_anneal_cycles,
                             ratio=cfg.lstm.kl_anneal_ratio)
        loss.backward()
        optimizer.step()
        step += 1

        plot_reconstructed_signal(cycle.detach().cpu().numpy(), recon_cycle.detach().cpu().numpy(),
                                  id_cycle.item(), activity[0], recon_error=recon_error,
                                  status="Training", save_dir=save_dir)

    logger.info(f"Epoch {epoch + 1}, Loss: {loss.item()}")

Is it a problem with the network architecture or with my PyTorch implementation? I’m confused!

I tried a skip (residual) connection and, of course, the results became good.

Regards.

Rémy

I should add that I have already read posts about the same problem:

  • My data are scaled. I tried Standard, Robust, and MinMax scaling.
  • I tried a very low learning rate to capture the local variations.
  • I trained for a very large number of epochs.

Always the same problem.

Rémy

I made some investigations during the training phase. I followed the gradients and I see this kind of values:

  • Zero gradient for decoder_lstm.weight_ih_l0
  • Parameter stats:
      • Mean: 0.020184
      • Std: 0.339163
      • Max: 0.577022
  • Parameter requires grad: True
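
These numbers come from a check like this one (a minimal sketch, run right after loss.backward() in the training loop above):

for name, p in model.named_parameters():
    # Report the gradient norm plus parameter statistics for each tensor
    grad_norm = p.grad.norm().item() if p.grad is not None else float("nan")
    print(f"{name}: grad norm={grad_norm:.6f}, mean={p.data.mean().item():.6f}, "
          f"std={p.data.std().item():.6f}, max={p.data.max().item():.6f}, "
          f"requires_grad={p.requires_grad}")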

My LSTM-VAE has this structure:

  • Model summary:
  • Parameter: encoder_lstm.weight_ih_l0, size: torch.Size([256, 3])
  • Parameter: encoder_lstm.weight_hh_l0, size: torch.Size([256, 64])
  • Parameter: encoder_lstm.bias_ih_l0, size: torch.Size([256])
  • Parameter: encoder_lstm.bias_hh_l0, size: torch.Size([256])
  • Parameter: encoder_lstm.weight_ih_l1, size: torch.Size([256, 64])
  • Parameter: encoder_lstm.weight_hh_l1, size: torch.Size([256, 64])
  • Parameter: encoder_lstm.bias_ih_l1, size: torch.Size([256])
  • Parameter: encoder_lstm.bias_hh_l1, size: torch.Size([256])
  • Parameter: encoder_ln.weight, size: torch.Size([64])
  • Parameter: encoder_ln.bias, size: torch.Size([64])
  • Parameter: hidden_to_mean.weight, size: torch.Size([10, 64])
  • Parameter: hidden_to_mean.bias, size: torch.Size([10])
  • Parameter: hidden_to_logvar.weight, size: torch.Size([10, 64])
  • Parameter: hidden_to_logvar.bias, size: torch.Size([10])
  • Parameter: latent_to_hidden.weight, size: torch.Size([64, 10])
  • Parameter: latent_to_hidden.bias, size: torch.Size([64])
  • Parameter: decoder_lstm.weight_ih_l0, size: torch.Size([256, 3])
  • Parameter: decoder_lstm.weight_hh_l0, size: torch.Size([256, 64])
  • Parameter: decoder_lstm.bias_ih_l0, size: torch.Size([256])
  • Parameter: decoder_lstm.bias_hh_l0, size: torch.Size([256])
  • Parameter: decoder_lstm.weight_ih_l1, size: torch.Size([256, 64])
  • Parameter: decoder_lstm.weight_hh_l1, size: torch.Size([256, 64])
  • Parameter: decoder_lstm.bias_ih_l1, size: torch.Size([256])
  • Parameter: decoder_lstm.bias_hh_l1, size: torch.Size([256])
  • Parameter: decoder_ln.weight, size: torch.Size([64])
  • Parameter: decoder_ln.bias, size: torch.Size([64])
  • Parameter: hidden_to_output.weight, size: torch.Size([3, 64])
  • Parameter: hidden_to_output.bias, size: torch.Size([3])

This means that my LSTM-VAE is not able to learn from the VAE latent space (the gradient of the first decoder layer is always 0). Does somebody know why?
My input sequences are long (around 10,000 points). Is that the problem? Normally an LSTM is supposed to handle vanishing gradients thanks to its internal gates.
Thanks in advance if you have ideas!

Regards.

Rémy

No ideas?
I discovered that there are some techniques to avoid this problem with LSTMs: chunking the long sequence, or truncated backpropagation through time (TBPTT).
Do you know these techniques?
And do you know if a GRU is more resilient?
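
From what I read, TBPTT splits the long sequence into chunks, carries the hidden state across chunks, and detaches it between chunks so that gradients only flow within one chunk. A minimal sketch of the idea on a plain LSTM (hypothetical lstm, head, loss_fn and optimizer; not yet adapted to the VAE):

def tbptt_step(lstm, head, loss_fn, optimizer, x, chunk_len=200):
    # x: one long cycle of shape (1, T, C)
    state = None
    total_loss = 0.0
    for t in range(0, x.size(1), chunk_len):
        chunk = x[:, t:t + chunk_len, :]
        optimizer.zero_grad()
        out, state = lstm(chunk, state)
        loss = loss_fn(head(out), chunk)  # e.g. reconstruct the current chunk
        loss.backward()
        optimizer.step()
        # Detach so the next chunk does not backpropagate into this one
        state = tuple(s.detach() for s in state)
        total_loss += loss.item()
    return total_loss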