I have a VAE that is trained to encode and generate new sentences given a dataset of existing sentences. The way I chose to do it is I built another Autoencoder that creates a representation of a sentence which can be decoded back to the original sentence (reconstuction is perfect). Then I use the encodings generated by the AE as input for the VAE. The VAE in turn takes in a sentence, embeds every token in a 64-dimensional space and then permorms the reparametrization trick in order to recreate the original embedding that the AE has produced. Essentially I turn sentences into something that looks like an image using an AE and then I have the VAE recreate that “image”. Then I use the original AE’s decoder part to turn that “image” back into a sentence.
I use BCELoss for the reconstruction and KL Divergence loss in order to force the model to stick to a normal distribution. The loss begins at 4000 or so and slowly decreases down to 0.08, but the results are always the same nonsensical sentence. In fact, all the VAE seems to have learned is to output numbers that are close to 0.5, which of course generates the same sentence when decoded. It is worth noting that the AE which generates the encoding of the sentence uses a sigmoid layer in order to keep the encoded vector within a range of 0 to 1 (much like an image would be). The AE looks like this:
class CNNEncode(nn.Module):
def __init__(self):
super().__init__()
# INFO
# LAYERS
self.embed = nn.Embedding(vocab_size, n_embed*4)
self.conv1 = nn.Conv1d(n_embed*4, max_l, 3, stride=3, padding=1, dilation = 2)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv1d(max_l, n_embed*1, 3, stride=3, padding=1, dilation = 4)
self.fc = nn.Linear(512, 256)
self.fc2 = nn.Linear(256, 64*max_l)
self.ln = nn.LayerNorm(256)
self.ln2 = nn.LayerNorm(64, 1)
self.out = nn.Linear(64, vocab_size)
self.sig = nn.Sigmoid()
def forward(self, x):
emb = self.embed(x)
B,T,C = emb.shape
emb = emb.view(B, C, T)
logits = self.conv1(emb)
logits = self.relu(logits)
#print(logits.shape)
logits = self.conv2(logits)
logits = self.relu(logits)
#print(logits.shape)
B,T,C = logits.shape
logits = logits.view(B, -1)
logits = torch.tanh(logits)
#print(logits.shape)
logits = self.fc(logits)
#logits = torch.tanh(logits)
logits = self.ln(logits)
logits = self.sig(logits)
#print(logits)
logits = self.fc2(logits)
#logits = torch.tanh(logits)
#logits = self.ln2(logits)
logits = self.relu(logits)
logits = logits.view(B, max_l, 64)
#print(logits.shape)
out = self.out(logits)
return out
and the VAE looks like this:
class VariationalAutoEncoder(nn.Module):
def __init__(self, input_dim, n_dim=128, h_dim=100, z_dim=64):
super().__init__()
#encoder, input is the embedding extracted from network A
self.img_2hid = nn.Linear(256, h_dim)
self.hid_2mu = nn.Linear(h_dim, z_dim)
self.hid_2sigma = nn.Linear(h_dim, z_dim)
#decode
self.z_2hid = nn.Linear(z_dim, h_dim)
self.hid_2img = nn.Linear(h_dim, 256)
self.relu = nn.ReLU()
self.sig = nn.Sigmoid()
def encode(self, x):
h = self.relu(self.img_2hid(x))
h = self.sig(h)
#print("h", h.shape)
mu = self.hid_2mu(h)
sigma = self.hid_2sigma(h)
return mu, sigma
def decode(self, z):
h = self.relu(self.z_2hid(z))
#print("dec h", h.shape)
return self.sig(self.hid_2img(h))
def forward(self, x):
mu, sigma = self.encode(x)
epsilon = torch.randn_like(sigma)
z_parametrized = mu + sigma* epsilon
x_reconstructed = self.decode(z_parametrized)
return x_reconstructed, mu, sigma
model = VariationalAutoEncoder(input_dim=max_l, n_dim=128, h_dim=200, z_dim=64).to(device)
opt = torch.optim.Adam(model.parameters(), lr=lr)
print(sum(p.nelement() for p in model.parameters()))
loss_fn = nn.BCELoss(reduction = "sum")
Why do I keep getting erroneous results even when loss is close to 0? And how should I approach the problem? I can provide more info if necessary.