I am trying to train a VAE to replicate Figure 1c in this paper:
In other words, I want to train a VAE on 2D points that form a pinwheel (or really any weird shape) and then sample 2D points from the trained VAE that also form a pinwheel. I am using the VAE from the PyTorch examples with two modifications:
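A pinwheel dataset can be generated roughly as follows. This is a sketch, not the paper's exact generator. The helper `make_pinwheel` and all of its parameter values (number of arms, noise levels, twist `rate`) are assumptions. Each arm starts as an elongated Gaussian blob centred at (1, 0). Each point is then rotated by its arm's base angle plus a twist that grows with the point's radial coordinate.

```python
import numpy as np


def make_pinwheel(num_arms=5, points_per_arm=1000, radial_std=0.3,
                  tangential_std=0.05, rate=0.25, seed=0):
    """Sample 2D points arranged in a pinwheel (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Base angle of each arm, evenly spaced around the circle.
    arm_angles = np.linspace(0.0, 2.0 * np.pi, num_arms, endpoint=False)
    # Gaussian blob centred at (1, 0): wide along x, narrow along y.
    pts = rng.standard_normal((num_arms * points_per_arm, 2))
    pts *= np.array([radial_std, tangential_std])
    pts[:, 0] += 1.0
    labels = np.repeat(np.arange(num_arms), points_per_arm)
    # Rotate each point by its arm angle plus a twist that grows with radius.
    angles = arm_angles[labels] + rate * np.exp(pts[:, 0])
    cos, sin = np.cos(angles), np.sin(angles)
    x = pts[:, 0] * cos - pts[:, 1] * sin
    y = pts[:, 0] * sin + pts[:, 1] * cos
    data = np.stack([x, y], axis=1)
    # Shuffle rows so the arms are interleaved.
    return rng.permutation(data).astype(np.float32)
```

With the defaults, `make_pinwheel()` returns a `(5000, 2)` float32 array. Because the arms go all the way around the origin, the coordinates include negative values and are not confined to [0, 1].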
- All the layers are 2-dimensional:

      self.fc1 = nn.Linear(2, 2)
      self.fc21 = nn.Linear(2, 2)
      self.fc22 = nn.Linear(2, 2)
      self.fc3 = nn.Linear(2, 2)
      self.fc4 = nn.Linear(2, 2)
- I changed the loss to `F.mse_loss` (I also tried Euclidean distance) instead of `F.binary_cross_entropy`, since I am dealing with Cartesian points rather than black-and-white images.
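Below is a minimal sketch of the modified model and loss. It is a reconstruction from the two modifications listed above, not the exact code used. It assumes everything else matches the PyTorch example VAE: ReLU hidden activations, a sigmoid on the decoder output, and a summed reconstruction loss plus the closed-form KL term.

```python
import torch
from torch import nn
from torch.nn import functional as F


class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)   # encoder hidden layer
        self.fc21 = nn.Linear(2, 2)  # mean of q(z|x)
        self.fc22 = nn.Linear(2, 2)  # log-variance of q(z|x)
        self.fc3 = nn.Linear(2, 2)   # decoder hidden layer
        self.fc4 = nn.Linear(2, 2)   # decoder output (a 2D point)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I), so gradients reach mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        # Assumed unchanged from the MNIST example: sigmoid squashes outputs into (0, 1).
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 2))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: summed MSE in place of the example's binary cross-entropy.
    recon = F.mse_loss(recon_x, x.view(-1, 2), reduction="sum")
    # Closed-form KL divergence between N(mu, diag(exp(logvar))) and N(0, I).
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld
```

If the sigmoid is kept, `decode` can only produce coordinates strictly between 0 and 1.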
However, the VAE does not learn the pinwheel shape. Here is what I get when I sample from the trained VAE 10,000 times:
I’ve tried both Adam and SGD with lr=0.001. Any idea what is going wrong?
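Below is a sketch of the training and sampling steps described here: Adam or SGD at `lr=0.001`, then decoding 10,000 draws from the N(0, I) prior, as the PyTorch example does for its sample images. `train_vae`, `sample_vae`, the `use_sgd` flag, and the epoch and batch-size defaults are hypothetical. The model is assumed to follow the example's interface: `forward` returns `(recon, mu, logvar)` and `decode` maps latent codes to points.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_vae(model, loss_fn, data, epochs=100, batch_size=128,
              lr=1e-3, use_sgd=False):
    """Hypothetical training loop; returns the average per-point loss per epoch."""
    if use_sgd:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(TensorDataset(data), batch_size=batch_size, shuffle=True)
    model.train()
    history = []
    for _ in range(epochs):
        total = 0.0
        for (x,) in loader:
            optimizer.zero_grad()
            recon, mu, logvar = model(x)
            loss = loss_fn(recon, x, mu, logvar)
            loss.backward()
            optimizer.step()
            total += loss.item()
        history.append(total / len(data))
    return history


@torch.no_grad()
def sample_vae(model, num_samples=10000, latent_dim=2):
    """Draw z ~ N(0, I) and decode each z into a 2D point."""
    model.eval()
    z = torch.randn(num_samples, latent_dim)
    return model.decode(z)
```

Returning the per-epoch loss makes it easy to check whether training is still improving before drawing samples.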