Loss becomesn nan when src_key_padding_mask is used

I’m building a transformer-based architecture using TransformerEncoderLayer and TransformerDecoderLayer provided by PyTorch. The model was running properly without any masks. Since the trained model was printing <pad> in the output, I tried passing a src_key_padding_mask and tgt_key_padding_mask.The mask is created by the method make_padding_key_mask() which returns a ByteTensor of shape (N,S) or (N,T). Now model started getting nan outputs from encoder as well as decoder. Is there anything wrong with way I implemented make_padding_key_mask()

class Transformer_VAE(nn.Module):
    def __init__(self, head, vocab_size, embedding_size, latent_dim, device = 'cpu', pad_idx = 0, start_idx = 1, end_idx = 2, unk_idx = 3):
        super(Transformer_VAE, self).__init__()
        self.head = head
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.device = device
        self.embed = WordEmbedding(self.vocab_size, self.embedding_size)
        self.postional_encoding = PostionalEncoding(embedding_size, device)
        self.encoder = nn.TransformerEncoderLayer(self.embedding_size, self.head)
        self.decoder = nn.TransformerDecoderLayer(self.embedding_size, self.head)
        self.hidden_to_mean = nn.Linear(self.embedding_size, latent_dim)
        self.hidden_to_logvar = nn.Linear(self.embedding_size, latent_dim)
        self.latent_to_embed = nn.Linear(self.latent_dim, self.embedding_size)
        self.out_linear = nn.Linear(self.embedding_size, vocab_size)
        self.pad_idx = pad_idx
        self.start_idx = start_idx
        self.end_idx = end_idx
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def make_padding_key_mask(self, x):
        batch_size, max_len = x.size()[:2]
        pad_mask = (x != self.pad_idx).expand(batch_size, max_len)
        return pad_mask

    def make_memory_mask(self, tgt_sz, src_sz):
        mask = (torch.triu(torch.ones(src_sz, tgt_sz)) == 1).transpose(0, 1)
        mem_mask = mask.float().masked_fill(mask == 0, 10e-9).masked_fill(mask == 1, float(0.0))
        return mem_mask

    def make_tgt_mask(self, tgt_len):
        tgt_mask = torch.tril(torch.ones((tgt_len, tgt_len)))
        if self.device=='gpu':
          tgt_mask = tgt_mask.cuda()
        #tgt_mask = tgt_mask.view(batch_size, 1, tgt_len, tgt_len)
        return tgt_mask

    def get_tgt_embedding(self, tgt, pred):
        pred = torch.argmax(pred, dim = -1)
        tgt_indices = torch.cat([tgt, pred], dim = 0)
        tgt_embed = self.embed(tgt_indices)
        return tgt_indices, tgt_embed

    def decode(self, tgt, z, tgt_mask, tgt_pad_mask):
        z = self.latent_to_embed(z)
        out = self.decoder(tgt, z, tgt_mask, tgt_key_padding_mask = tgt_pad_mask)
        out = self.out_linear(out)
        return out
    def forward(self, x):
        batch_size, maxlen = x.size()[:2]
        tgt_mask = self.make_tgt_mask(maxlen)
        mem_mask = self.make_memory_mask(maxlen, maxlen)
        src_padding_mask = self.make_padding_key_mask(x)
        tgt_padding_mask = self.make_padding_key_mask(x)
        if self.device == 'gpu':
            tgt_mask = tgt_mask.cuda()
        src = self.embed(x).view(maxlen, batch_size, self.embedding_size)
        src = self.postional_encoding(src) 
        x = self.encoder(src, src_key_padding_mask = src_padding_mask)
        mean, logvar = self.hidden_to_mean(x), self.hidden_to_logvar(x)
        z = self.reparameterize(mean, logvar)
        out = self.decode(src, z, tgt_mask, tgt_padding_mask)
        return mean, logvar, out