I’m building a transformer-based architecture using TransformerEncoderLayer
and TransformerDecoderLayer
provided by PyTorch. The model was running properly without any masks. Since the trained model was printing <pad>
in the output, I tried passing a src_key_padding_mask
and tgt_key_padding_mask
.The mask is created by the method make_padding_key_mask()
which returns a ByteTensor of shape (N,S)
or (N,T)
. Now model started getting nan
outputs from encoder as well as decoder. Is there anything wrong with way I implemented make_padding_key_mask()
class Transformer_VAE(nn.Module):
def __init__(self, head, vocab_size, embedding_size, latent_dim, device = 'cpu', pad_idx = 0, start_idx = 1, end_idx = 2, unk_idx = 3):
super(Transformer_VAE, self).__init__()
self.head = head
self.embedding_size = embedding_size
self.vocab_size = vocab_size
self.latent_dim = latent_dim
self.device = device
self.embed = WordEmbedding(self.vocab_size, self.embedding_size)
self.postional_encoding = PostionalEncoding(embedding_size, device)
self.encoder = nn.TransformerEncoderLayer(self.embedding_size, self.head)
self.decoder = nn.TransformerDecoderLayer(self.embedding_size, self.head)
self.hidden_to_mean = nn.Linear(self.embedding_size, latent_dim)
self.hidden_to_logvar = nn.Linear(self.embedding_size, latent_dim)
self.latent_to_embed = nn.Linear(self.latent_dim, self.embedding_size)
self.out_linear = nn.Linear(self.embedding_size, vocab_size)
self.pad_idx = pad_idx
self.start_idx = start_idx
self.end_idx = end_idx
def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return mu + eps*std
def make_padding_key_mask(self, x):
batch_size, max_len = x.size()[:2]
pad_mask = (x != self.pad_idx).expand(batch_size, max_len)
pad_mask.type(torch.ByteTensor)
return pad_mask
def make_memory_mask(self, tgt_sz, src_sz):
mask = (torch.triu(torch.ones(src_sz, tgt_sz)) == 1).transpose(0, 1)
mem_mask = mask.float().masked_fill(mask == 0, 10e-9).masked_fill(mask == 1, float(0.0))
return mem_mask
def make_tgt_mask(self, tgt_len):
tgt_mask = torch.tril(torch.ones((tgt_len, tgt_len)))
if self.device=='gpu':
tgt_mask = tgt_mask.cuda()
#tgt_mask = tgt_mask.view(batch_size, 1, tgt_len, tgt_len)
return tgt_mask
def get_tgt_embedding(self, tgt, pred):
pred = torch.argmax(pred, dim = -1)
tgt_indices = torch.cat([tgt, pred], dim = 0)
tgt_embed = self.embed(tgt_indices)
return tgt_indices, tgt_embed
def decode(self, tgt, z, tgt_mask, tgt_pad_mask):
z = self.latent_to_embed(z)
out = self.decoder(tgt, z, tgt_mask, tgt_key_padding_mask = tgt_pad_mask)
out = self.out_linear(out)
return out
def forward(self, x):
batch_size, maxlen = x.size()[:2]
tgt_mask = self.make_tgt_mask(maxlen)
mem_mask = self.make_memory_mask(maxlen, maxlen)
src_padding_mask = self.make_padding_key_mask(x)
tgt_padding_mask = self.make_padding_key_mask(x)
if self.device == 'gpu':
tgt_mask = tgt_mask.cuda()
src = self.embed(x).view(maxlen, batch_size, self.embedding_size)
src = self.postional_encoding(src)
x = self.encoder(src, src_key_padding_mask = src_padding_mask)
print(torch.sum(torch.isnan(x)))
mean, logvar = self.hidden_to_mean(x), self.hidden_to_logvar(x)
z = self.reparameterize(mean, logvar)
out = self.decode(src, z, tgt_mask, tgt_padding_mask)
return mean, logvar, out