VAE on pinwheel of 2D points

I am trying to train a VAE to replicate Figure 1c in this paper:

[screenshot of Figure 1c from the paper]

In other words, I want to train a VAE on 2D points that form a pinwheel (or really any weird shape) and then sample 2D points from the trained VAE that also form a pinwheel. I am using the example PyTorch VAE with two modifications:

  1. All the layers are 2-dimensional:

    self.fc1  = nn.Linear(2, 2)
    self.fc21 = nn.Linear(2, 2)
    self.fc22 = nn.Linear(2, 2)
    self.fc3  = nn.Linear(2, 2)
    self.fc4  = nn.Linear(2, 2)
    
  2. I changed the loss to F.mse_loss (I also tried Euclidean distance) instead of F.binary_cross_entropy, since I am dealing with Cartesian points rather than black-and-white images. A condensed sketch of the whole setup is below.

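For completeness, here is roughly how everything fits together after those two changes. This is condensed from the example rather than my exact script, so treat the activations and the `reduction` argument as my best recollection:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class VAE(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1  = nn.Linear(2, 2)   # encoder layer
            self.fc21 = nn.Linear(2, 2)   # mean head
            self.fc22 = nn.Linear(2, 2)   # log-variance head
            self.fc3  = nn.Linear(2, 2)   # decoder layers
            self.fc4  = nn.Linear(2, 2)

        def encode(self, x):
            h1 = F.relu(self.fc1(x))
            return self.fc21(h1), self.fc22(h1)

        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std

        def decode(self, z):
            h3 = F.relu(self.fc3(z))
            return self.fc4(h3)           # no sigmoid: outputs are unbounded 2D points

        def forward(self, x):
            mu, logvar = self.encode(x)
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar

    def loss_function(recon_x, x, mu, logvar):
        recon = F.mse_loss(recon_x, x, reduction='sum')                # was binary_cross_entropy
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # standard KL term
        return recon + kld
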
That said, I am unable to learn the pinwheel dataset. Here is what happens if I sample from the VAE 10000 times:

[scatter plot of 10,000 samples from the model trained with mse_loss]
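
For reference, this is how I generate the samples, assuming a standard normal prior (`model` here is the trained VAE from the sketch above):

    import torch

    with torch.no_grad():
        z = torch.randn(10000, 2)     # latents drawn from the N(0, I) prior
        samples = model.decode(z)     # decoded 2D points, scatter-plotted above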

I’ve tried both Adam and SGD with lr=0.001. Any idea what is going wrong?

I have a similar problem, have you found a solution by now?

It’s been a while, but I think I was able to get this to work by mimicking the architecture in the paper, which is roughly linear > tanh > linear > tanh. See the definition here: https://github.com/mattjj/svae/blob/master/experiments/gmm_svae_synth.py.
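
Translated to PyTorch, the shape of the networks was something like the sketch below. This is a rough reconstruction from memory, and the hidden width is a placeholder of mine, so check gmm_svae_synth.py in that repo for the exact numbers:

    import torch.nn as nn

    hidden = 40   # placeholder width, not necessarily what the repo uses

    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            # linear > tanh > linear > tanh trunk, then linear heads for mean and log-variance
            self.trunk = nn.Sequential(
                nn.Linear(2, hidden), nn.Tanh(),
                nn.Linear(hidden, hidden), nn.Tanh(),
            )
            self.mu = nn.Linear(hidden, 2)
            self.logvar = nn.Linear(hidden, 2)

        def forward(self, x):
            h = self.trunk(x)
            return self.mu(h), self.logvar(h)

    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            # same linear > tanh pattern mapping 2D latents back to 2D points
            self.net = nn.Sequential(
                nn.Linear(2, hidden), nn.Tanh(),
                nn.Linear(hidden, hidden), nn.Tanh(),
                nn.Linear(hidden, 2),
            )

        def forward(self, z):
            return self.net(z)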
