TransformerDecoderLayer AssertionError

I read an interesting paper on a transformer-based variational autoencoder (VAE), so I tried to replicate a simpler version using nn.TransformerEncoderLayer and nn.TransformerDecoderLayer. The architecture is given in the code below:

import torch
import torch.nn as nn

embedding_size = 256
n_heads = 8
batch_size = 16
latent_dim = 32
class Transformer_VAE(nn.Module):
    
    def __init__(self, head, vocab_size, embedding_size, latent_dim, device = 'cpu', pad_idx = 0, start_idx = 1, end_idx = 2, unk_idx = 3):
        super(Transformer_VAE, self).__init__()
        self.head = head
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
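        # WordEmbedding and PostionalEncoding are custom modules defined elsewhere in the script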
        self.embed = WordEmbedding(self.vocab_size, self.embedding_size)
        self.postional_encoding = PostionalEncoding(embedding_size, device)
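        # NOTE: the encoder is built with d_model = embedding_size (256),
        # while the decoder is built with d_model = latent_dim (32)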
        self.encoder = nn.TransformerEncoderLayer(self.embedding_size, head)
        self.decoder = nn.TransformerDecoderLayer(self.latent_dim, head)
        self.hidden_to_mean = nn.Linear(self.embedding_size, latent_dim)
        self.hidden_to_logvar = nn.Linear(self.embedding_size, latent_dim)
        self.out_linear = nn.Linear(self.embedding_size, vocab_size)
        
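    # Standard VAE reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)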
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
        
    def forward(self, x):
        batch_size, maxlen = x.size()[:2]
        src = self.embed(x)
        x = self.postional_encoding(src)
        x = self.encoder(x)
        print(x.size())
        mean, logvar = self.hidden_to_mean(x), self.hidden_to_logvar(x)
        z = self.reparameterize(mean, logvar)
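        # AssertionError is raised on the next line: src has last dimension
        # embedding_size (256), but self.decoder was built with d_model = latent_dim (32)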
        out = self.decoder(src, z)
        out = self.out_linear(out)
        return mean, logvar, out

I'm getting an AssertionError when I run the self.decoder(src, z) line. From what I understood from the documentation, TransformerDecoderLayer takes the target sequence and the output of the encoder (the memory) as its inputs. Since this is a VAE, the target is the same as the input, and the encoder output has been mapped down to the latent dimension, so these two values are what I pass to the decoder layer.

My bad, solved it myself. The decoder was constructed with d_model = latent_dim (32), so passing src, whose last dimension is embedding_size (256), trips the shape assertion inside the decoder's multi-head attention. Upscaling the latent vectors back to the embedding size (and building the decoder with embedding_size) before passing them to the TransformerDecoderLayer fixes it.
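For anyone who hits the same error, here is a minimal, self-contained sketch of the fix; the layer name latent_to_hidden and the random tensors are my own stand-ins for illustration:

import torch
import torch.nn as nn

embedding_size, latent_dim, n_heads = 256, 32, 8
seq_len, batch_size = 10, 16

# Build the decoder in the embedding space instead of the latent space
decoder = nn.TransformerDecoderLayer(embedding_size, n_heads)
# Projects the reparameterized latent codes back up to embedding_size
latent_to_hidden = nn.Linear(latent_dim, embedding_size)

src = torch.randn(seq_len, batch_size, embedding_size)  # embedded target sequence
z = torch.randn(seq_len, batch_size, latent_dim)        # stand-in for reparameterize(mean, logvar)

memory = latent_to_hidden(z)  # (seq_len, batch_size, embedding_size)
out = decoder(src, memory)    # no AssertionError: both inputs match d_model
print(out.size())             # torch.Size([10, 16, 256])

In the model above this amounts to building self.decoder with self.embedding_size, adding self.latent_to_hidden = nn.Linear(self.latent_dim, self.embedding_size) in __init__, and calling self.decoder(src, self.latent_to_hidden(z)) in forward. As a side effect, the decoder output then has last dimension embedding_size, which is what self.out_linear already expects.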