LSTM-VAE Unable to Reconstruct Input Time Series

I created an artificial dataset of sine curves of varying frequencies and built an LSTM-VAE to reconstruct the data and see if the model can separate different frequencies in the latent space.

From what I’ve seen, my implementation is pretty by-the-book: the encoder takes in a univariate time series and I take hidden[-1, :, :] from its output, which is of length hidden_dim. This is passed into 2 FC layers separately to produce mean and logvar, both of length latent_dim. I do the reparameterization trick to get z, which is passed through another FC layer to get back to hidden_dim. This vector is repeated seq_len times to provide the input to the decoding lstm layer, which produces reconstructions.

I believe I’m computing the loss correctly as well (reconstruction + beta*KL loss). Over 100 epochs, the reconstruction loss (1.1->0.99) remains stagnant. KL divergence is much smaller in magnitude, and either decreases or remains stagnant over the course of training. In the latent space, I actually see the 3 different frequencies separated into 3 different parts of the latent space, which is cool. Unfortunately, the reconstructed input is just a flat line.

I’ve tried different hidden/latent sizes, different values of beta (even 0, eliminating KL divergence entirely), and different learning rates and have had no success, which leads me to believe my implementation is incorrect somewhere. Unfortunately, I’ve had no luck debugging it.

Does anyone have any ideas on what might be wrong?

Model code:

class Encoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, num_layers):
        super(Encoder, self).__init__()
        self.lstm = torch.nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, x):
        x, (hidden, cell) = self.lstm(x)
        return x, (hidden, cell)

class Sampler(torch.nn.Module):
    def __init__(self):
        super(Sampler, self).__init__()

    def forward(self, z_mean, z_logvar):
        epsilon = torch.randn_like(z_mean)
        z = z_mean + torch.exp(0.5 * z_logvar) * epsilon
        return z

class Decoder(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, batch_size, seq_len, num_layers):
        super(Decoder, self).__init__()
        self.seq_len = seq_len

        self.fc = torch.nn.Linear(
            in_features=input_dim,
            out_features=hidden_dim,
        )
        self.lstm = torch.nn.LSTM(
            hidden_dim,
            output_dim,
            num_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, x, hidden):
        x = self.fc(x)
        x = x.unsqueeze(1).repeat(1, self.seq_len, 1)
        x, (hidden, cell) = self.lstm(x)
        return x, (hidden, cell)

class VAE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, batch_size, seq_len, num_encoder_layers, num_decoder_layers):
        super(VAE, self).__init__()

        self.encoder = Encoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            num_layers=num_encoder_layers,
        )
        self.fc_mean = torch.nn.Linear(
            in_features=hidden_dim,
            out_features=latent_dim,
        )
        self.fc_logvar = torch.nn.Linear(
            in_features=hidden_dim,
            out_features=latent_dim,
        )

        self.sampler = Sampler()
        self.decoder = Decoder(
            input_dim=latent_dim,
            hidden_dim=hidden_dim,
            output_dim=input_dim,
            batch_size=batch_size,
            seq_len=seq_len,
            num_layers=num_decoder_layers,
        )

    def forward(self, x):
        x, (hidden, cell) = self.encoder(x)

        z_mean = self.fc_mean(hidden[-1, :, :])
        z_logvar = self.fc_logvar(hidden[-1, :, :])
        
        if self.training:
            z = self.sampler(z_mean, z_logvar)
        else:
            z = z_mean
        
        reconstruction, (hidden, cell) = self.decoder(z, hidden)
        return reconstruction, z_mean, z_logvar

class VAELoss(torch.nn.Module):
    def __init__(self, beta):
        super(VAELoss, self).__init__()
        self.beta = beta

        self.reconstruction_loss = torch.nn.MSELoss()

    def forward(self, x, reconstruction, z_mean, z_logvar):
        reconstruction_loss = self.reconstruction_loss(reconstruction, x)

        kl_loss = -0.5 * torch.mean(1 + z_logvar - torch.pow(z_mean, 2) - torch.exp(z_logvar))

        total_loss = reconstruction_loss + self.beta * kl_loss
        return total_loss, reconstruction_loss, kl_loss

Training loop:

model = VAE(
    input_dim=1,
    hidden_dim=64,
    latent_dim=32,
    batch_size=300,
    seq_len=100,
    num_encoder_layers=1,
    num_decoder_layers=1,
).to(device)

loss_fn = VAELoss(
    beta=0.1,
)

opt = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
)

model.train()
for epoch in range(100):

    for data_matrix, labels in dataloader:
        data_matrix = data_matrix.to(device)

        reconstruction, z_mean, z_logvar = model(data_matrix)

        total_loss, reconstruction_loss, kl_loss = loss_fn(data_matrix, reconstruction, z_mean, z_logvar)

        opt.zero_grad()
        total_loss.backward()
        opt.step()

    print(f"Epoch: {epoch + 1}, Total Loss: {total_loss}, Reconstruction Loss: {reconstruction_loss}, KL Loss: {kl_loss}")

Latent space, loss curves, and reconstruction example: