I’m training a VAE (let’s call this VAE_combined) on the embeddings of three other VAEs (let’s call them VAE_1, VAE_2, and VAE_3). I’m wondering whether I should generate the training dataset for the VAE_combined using VAE_1/2/3.train() or VAE_1/2/3.eval()? Both the test and train losses for VAE_combined when training on embeddings generated using .train() are significantly lower than the losses obtained on embeddings generated using .eval(). But I’m not sure which one is “correct”.
You’re spot-on about freezing the base VAEs to keep their domain-specific latents. Use .eval() and requires_grad_(False) to ensure they’re fixed:
for vae in [base_vae1, base_vae2, base_vae3]:
    vae.eval()
    for param in vae.parameters():
        param.requires_grad_(False)
A few tips to avoid pitfalls:
- Normalize Latents: Z-score the base VAE latents per dimension (subtract mean, divide by std, with statistics computed over the training set) if their scales differ:

    z1 = (z1 - z1.mean(dim=0)) / z1.std(dim=0)
    z_input = torch.cat([z1, z2, z3], dim=-1)
- Use MSE for reconstruction (input vs. output latents) plus KL divergence, and keep the two reductions consistent: a mean-reduced MSE paired with a summed KL makes β depend on latent dimensionality. Consider β-VAE (e.g., β=0.1 initially) to balance the KL term:

    recon_loss = torch.nn.functional.mse_loss(z_recon, z_input, reduction="sum")
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = (recon_loss + beta * kl_loss) / z_input.size(0)
- VAEs can be tricky. If loss isn’t converging, try a lower learning rate (e.g., 1e-4), KL annealing (β from 0 to 1), or monitoring for posterior collapse (KL near zero). Plot loss curves before extending training.
- Partial Inputs: For inference with fewer latents, pad missing ones with zeros or their mean to match input size.
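Putting the freezing and dataset-generation steps together, here is a minimal self-contained sketch. The TinyVAE class and its encode method returning (mu, logvar) are stand-ins for your trained models; the dimensions are illustrative, and taking mu (rather than a sampled z) as the embedding is one common choice, not the only one:

```python
import torch
import torch.nn as nn

class TinyVAE(nn.Module):
    """Stand-in for a trained base VAE; only the encoder half is sketched."""
    def __init__(self, in_dim=8, latent_dim=4):
        super().__init__()
        self.mu = nn.Linear(in_dim, latent_dim)
        self.logvar = nn.Linear(in_dim, latent_dim)

    def encode(self, x):
        return self.mu(x), self.logvar(x)

base_vaes = [TinyVAE() for _ in range(3)]

# Freeze the base VAEs: eval() fixes dropout/batch-norm behavior,
# requires_grad_(False) keeps their weights out of any optimizer step.
for vae in base_vaes:
    vae.eval()
    for p in vae.parameters():
        p.requires_grad_(False)

x = torch.randn(100, 8)  # placeholder for the real input batch
with torch.no_grad():    # no autograd graph needed when building the dataset
    latents = [vae.encode(x)[0] for vae in base_vaes]  # use mu as the embedding
z_input = torch.cat(latents, dim=-1)
print(z_input.shape)  # torch.Size([100, 12])
```

z_input is then the (fixed, reproducible) training input for VAE_combined.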
Thanks for your reply!
What’s the exact reason I should generate embeddings using .eval()? As I mentioned, I get significantly better results when generating using .train(), so it’s tempting to use those.
You said you get better results, but how many epochs are we talking? VAEs typically take a very long time to train (often on the order of 100x what other types of networks take). Can you share loss curves for both cases, with the base VAEs in eval mode and in train mode? I'm wondering whether you're training long enough to reach proper convergence.
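Independent of convergence: if the base VAEs contain dropout or batch-norm layers, .train() makes their encodings stochastic and batch-dependent, so the "dataset" you train VAE_combined on changes every pass. A minimal sketch with a hypothetical dropout encoder (whether your base VAEs actually use dropout or batch norm is an assumption; if they use neither, train() and eval() encodings coincide):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in encoder containing a dropout layer.
enc = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
x = torch.randn(4, 16)

enc.train()
a, b = enc(x), enc(x)   # a fresh dropout mask is drawn each call

enc.eval()
c, d = enc(x), enc(x)   # dropout is a no-op, so outputs are deterministic

print(torch.equal(a, b))  # almost surely False: embeddings differ per call
print(torch.equal(c, d))  # True: reproducible embeddings
```

With eval mode, every data point maps to one fixed embedding, which is what you want when the embeddings are the dataset itself.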
I don’t have any charts, but I’m training on a small dataset of just 100 data points (the ultimate goal is to build a larger dataset with active learning). Using the same dataset over 100 epochs, the reconstruction error is 3x lower using .train(). By that time, the network seems to have converged.