VAE can overfit one data sample, but outputs blank/noise when trained with more data

Hello,
I’d really appreciate some insight on this, since I’ve been struggling with it for a few days now: the VAE can overfit a single training sample without trouble, but as soon as I train it on more data the output turns into blanks or noise.

Here is my model’s architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F

latent_size = 5
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        # 1-D convolutional encoder: single-channel signal in, single-channel
        # feature map out (flattened to length 6942 in forward())
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, 3, stride=3),
            nn.ReLU(),
            nn.Conv1d(16, 32, 3, stride=3),
            nn.ReLU(),
            nn.Conv1d(32, 1, 3),
            nn.ReLU(),
        )
        # linear heads producing the latent mean and log-variance
        self.mean = nn.Linear(6942, latent_size)
        self.logvar = nn.Linear(6942, latent_size)

        # decoder: expand the latent vector with a linear layer, then upsample
        # back to signal length with transposed convolutions
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, 6942),
            nn.ReLU(),
            nn.ConvTranspose1d(1, 64, 3, stride=3),
            nn.ReLU(),
            nn.ConvTranspose1d(64, 32, 3, stride=3),
            nn.ReLU(),
            nn.ConvTranspose1d(32, 1, 23),
        )

    # reparameterization trick: z = mu + eps * std, with eps ~ N(0, 1)
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + (eps * std)
    
    def forward(self, x):
        encode = self.encoder(x)
        encode = encode.view(1, 6942)  # flatten the encoder output (batch size hard-coded to 1)
        mu = self.mean(encode)
        logvar = self.logvar(encode)
        z = self.reparameterize(mu, logvar)
        decode = self.decoder(z)
        return decode, mu, logvar
    
   
net = VAE()
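
For context, here is a quick shape check of a single forward pass. The input length of 62496 is just a value picked for this sketch so that the flattened encoder output comes out to the 6942 the linear layers expect; the actual signal content doesn't matter for the shape check.

# shape smoke test with dummy data (62496 is an assumed example length,
# chosen only so the encoder output flattens to 6942)
x = torch.randn(1, 1, 62496)            # (batch, channels, length)
x_hat, mu, logvar = net(x)
print(x.shape, x_hat.shape, mu.shape, logvar.shape)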

And here is the loss function:

def VAE_loss(x, x_hat, mu, logvar):
    # reconstruction term: summed cross-entropy between the reconstruction and the input
    reconstructed_loss = F.cross_entropy(x_hat, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the standard normal prior
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstructed_loss + kl_divergence

loss = VAE_loss
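
As a sanity check on the KL term (toy values, not real model outputs): with mu = 0 and logvar = 0, q(z|x) is exactly the standard normal prior, so the divergence should come out to zero.

# toy check of the KL term with made-up values
mu0 = torch.zeros(1, latent_size)
logvar0 = torch.zeros(1, latent_size)
kl0 = -0.5 * torch.sum(1 + logvar0 - mu0.pow(2) - logvar0.exp())
print(kl0)  # tensor(-0.), i.e. zero divergence from the prior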

Sincere thanks to anyone who can point me in the right direction. 🙂