Hello,
I’m trying to implement a LSTM-VAE to make anomalies detection. I have a problem ; the model is only able to learn a flat curve (like the mean) instead the complex signals in input.
My data are mutlivariate timeseries (3 channels) with no constant duration.
To simplify the dataloader, for the moment, I don’t use batching (pack, padd, mask). So my training loop ingest cycle by cycle.
I use data scaling (robust or standard), unidirectionnal multilayers LSTM with batchnorm and cyclical kl_annealing to ensure a good repartition in the loss, between the MSE (LSTM reconstruction error) and the KLD (latent Gaussain learnt).
Hereunder an reconstrcuted signal:
The first points of the reconstructed signal try to learn but quickly, the LSTM failed and learn a flat curve.
I tried Attention layer, differents size of hidden layers, conv1d, bLSTM, … and no improvement.
Is-it a problem of network architecture or a pytorch implementation ?
My model:
class LSTMVAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim, num_layers=1, dropout=0.3, device="cpu"):
super(LSTMVAE, self).__init__()
self.device = torch.device(device)
self.to(self.device)
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
self.num_layers = num_layers
self.dropout = nn.Dropout(dropout)
# Encoder
self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True).to(self.device)
self.encoder_ln = nn.LayerNorm(hidden_dim).to(self.device)
self.hidden_to_mean = nn.Linear(hidden_dim, latent_dim).to(self.device)
self.hidden_to_logvar = nn.Linear(hidden_dim, latent_dim).to(self.device)
# Decoder
self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim).to(self.device)
self.decoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True).to(self.device)
self.decoder_ln = nn.LayerNorm(hidden_dim).to(self.device)
self.hidden_to_output = nn.Linear(hidden_dim, input_dim).to(self.device)
# Residual layers
# self.residual_layer = nn.Sequential(nn.Linear(input_dim, input_dim), nn.Dropout(dropout), ).to(
# self.device).to(self.device)
# Weight initialization
self._initialize_weights()
def encode(self, x):
x = self.dropout(x)
_, (h, _) = self.encoder_lstm(x) # h : (num_layers, batch_size, hidden_dim)
h = h[-1] # last layer : (batch_size, hidden_dim)
mean = self.hidden_to_mean(h)
logvar = self.hidden_to_logvar(h)
return mean, logvar
def reparameterize(self, mean, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mean + eps * std
def decode(self, z, seq_len):
# Convert z to LSTM initial states
h = self.latent_to_hidden(z).unsqueeze(0).repeat(self.num_layers, 1, 1) # (1, batch_size, hidden_dim)
c = torch.zeros_like(h) # Initial memory cell
input_seq = torch.zeros((z.size(0), seq_len, self.input_dim)).to(z.device)
output, _ = self.decoder_lstm(input_seq, (h, c))
recon_x = self.hidden_to_output(output)
return recon_x
def forward(self, x, seq_len):
x = x.to(self.device)
mean, logvar = self.encode(x)
z = self.reparameterize(mean, logvar)
recon_x = self.decode(z, seq_len)
# Add residual connections
# recon_x = recon_x + self.residual_layer(x)
return recon_x, mean, logvar
def _initialize_weights(self):
for name, p in self.named_parameters():
if 'lstm' in name:
if 'weight_ih' in name:
nn.init.xavier_uniform_(p.data)
elif 'weight_hh' in name:
nn.init.orthogonal_(p.data)
elif 'bias_ih' in name:
p.data.fill_(0)
# Set forget-gate bias to 1
n = p.size(0)
p.data[(n // 4):(n // 2)].fill_(1)
elif 'bias_hh' in name:
p.data.fill_(0)
elif 'fc' in name:
if 'weight' in name:
nn.init.xavier_uniform_(p.data)
elif 'bias' in name:
p.data.fill_(0)
And my train loop:
model.train() step = 0 for epoch in range(cfg.lstm.epoch): for cycle, id_cycle, activity, _, seq_len in train_loader: logger.debug(f'Cycles: {id_cycle.item()}, activities: {activity}, seq_lens: {seq_len.item()}') # Push cycle tensor and se_len variable to GPU if present cycle, seq_len = cycle.to(device), seq_len.to(device) optimizer.zero_grad() recon_cycle, mean, logvar = model(cycle, seq_len) recon_error = nn.MSELoss()(recon_cycle, cycle).item() recon_errors_list.append(recon_error) loss = loss_function(recon_cycle, cycle, mean, logvar, anneal_function=cfg.lstm.kl_anneal_function, step=step, total_steps=total_steps, cycles=cfg.lstm.kl_anneal_cycles, ratio=cfg.lstm.kl_anneal_ratio) loss.backward() optimizer.step() step += 1 plot_reconstructed_signal(cycle.detach().cpu().numpy(), recon_cycle.detach().cpu().numpy(), id_cycle.item(), activity[0], recon_error=recon_error, status="Training", save_dir=save_dir) logger.info(f"Epoch {epoch + 1}, Loss: {loss.item()}")
Is it a problem of network architecture or a problem of pytorch implementation ? I’m confused !!
I tried a skip connection (residual) connexion and, of course, the results become good.
Regards.
Rémy