I read an interesting paper on a transformer-based variational autoencoder (VAE), so I tried to replicate a simpler model using nn.TransformerEncoderLayer and nn.TransformerDecoderLayer. The architecture code is given below:
import torch
import torch.nn as nn

# WordEmbedding and PostionalEncoding are custom modules defined elsewhere (not shown).

embedding_size = 256
n_heads = 8
batch_size = 16
latent_dim = 32

class Transformer_VAE(nn.Module):
    def __init__(self, head, vocab_size, embedding_size, latent_dim, device='cpu',
                 pad_idx=0, start_idx=1, end_idx=2, unk_idx=3):
        super(Transformer_VAE, self).__init__()
        self.head = head
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.embed = WordEmbedding(self.vocab_size, self.embedding_size)
        self.postional_encoding = PostionalEncoding(embedding_size, device)
        self.encoder = nn.TransformerEncoderLayer(self.embedding_size, head)
        self.decoder = nn.TransformerDecoderLayer(self.latent_dim, head)
        self.hidden_to_mean = nn.Linear(self.embedding_size, latent_dim)
        self.hidden_to_logvar = nn.Linear(self.embedding_size, latent_dim)
        self.out_linear = nn.Linear(self.embedding_size, vocab_size)

    def reparameterize(self, mu, logvar):
        # standard VAE reparameterization trick: z = mu + eps * sigma
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        batch_size, maxlen = x.size()[:2]
        src = self.embed(x)
        x = self.postional_encoding(src)
        x = self.encoder(x)
        print(x.size())
        # map the encoder output to the latent space and sample z
        mean, logvar = self.hidden_to_mean(x), self.hidden_to_logvar(x)
        z = self.reparameterize(mean, logvar)
        # decode using the embedded input as target and z as memory
        out = self.decoder(src, z)
        out = self.out_linear(out)
        return mean, logvar, out
I’m getting an AssertionError when I run the self.decoder(src, z) line. From what I understood from the documentation, TransformerDecoderLayer takes the target sequence and the output of the encoder as its inputs. Since this is a VAE, the target is the same as the input, and the encoder outputs have been mapped to the latent dimension, so these two values are passed to the decoder layer.
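For reference, here is a minimal standalone sketch of the input shapes nn.TransformerDecoderLayer expects (using the default (seq_len, batch, d_model) layout; the sequence length of 10 is just a made-up value, the other numbers come from my model):

import torch
import torch.nn as nn

d_model, n_heads, seq_len, batch = 32, 8, 10, 16   # d_model here matches latent_dim

decoder_layer = nn.TransformerDecoderLayer(d_model, n_heads)

tgt = torch.randn(seq_len, batch, d_model)     # target sequence, last dim must equal d_model
memory = torch.randn(seq_len, batch, d_model)  # encoder output, last dim must also equal d_model

out = decoder_layer(tgt, memory)
print(out.shape)  # torch.Size([10, 16, 32])

# A target (or memory) with a different feature size, e.g. 256 instead of 32,
# trips the embed_dim assert inside MultiheadAttention:
# bad_tgt = torch.randn(seq_len, batch, 256)
# decoder_layer(bad_tgt, memory)  # AssertionError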