My transformer NMT model is giving "nan" loss value

I am training my transformer model and its loss is "nan". I have tried various workarounds but couldn't figure it out. Please help me with this.
Here is my transformer model.

import torch.nn as nn
import torch
import math

class PositionalEncoding(nn.Module):
    def __init__(self, model_dim, dropout_rate=0.2, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.model_dim = model_dim
        self.dropout = nn.Dropout(dropout_rate)
        pos_enc = torch.zeros(max_len, model_dim)
        position = torch.arange(0, max_len, dtype=torch.float).view(-1,1)
        div_term = torch.exp(torch.arange(0, model_dim, 2).float() * (-math.log(10000.0) / model_dim))
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        pos_enc = pos_enc.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, inputs):
        inputs = inputs * math.sqrt(self.model_dim)
        inputs = inputs + self.pos_enc[:inputs.size(0), :]
        return self.dropout(inputs)

class Encoder(nn.Module):
    def __init__(self, src_vocab_len, model_dim, fc_dim, n_heads, n_enc_layers, pad_idx, dropout, activation):
        super(Encoder, self).__init__()
        self.src_embeddings = nn.Embedding(src_vocab_len, model_dim, padding_idx=pad_idx)
        self.pos_encoder = PositionalEncoding(model_dim, dropout)
        enc_layer = nn.TransformerEncoderLayer(model_dim, n_heads, fc_dim, dropout, activation=activation)
        enc_norm = nn.LayerNorm(model_dim)
        self.encoder = nn.TransformerEncoder(enc_layer, n_enc_layers, enc_norm)

    def forward(self, src, src_mask):
        src = self.src_embeddings(src)
        src = self.pos_encoder(src)
        return self.encoder(src.transpose(0,1), src_key_padding_mask=src_mask)  # src_mask is giving an error.
        # return self.encoder(src)
        
class Decoder(nn.Module):
    def __init__(self, tgt_vocab_len, model_dim, fc_dim, n_heads, n_dec_layers, pad_idx, dropout, activation):
        super(Decoder, self).__init__()
        self.tgt_embeddings = nn.Embedding(tgt_vocab_len, model_dim, padding_idx=pad_idx)
        self.pos_encoder = PositionalEncoding(model_dim, dropout)
        dec_layer = nn.TransformerDecoderLayer(model_dim, n_heads, fc_dim, dropout, activation=activation)
        dec_norm = nn.LayerNorm(model_dim)
        self.decoder = nn.TransformerDecoder(dec_layer, n_dec_layers, dec_norm)

    def forward(self, tgt, enc_encodings, tgt_sqr_mask=None, tgt_mask=None, src_mask=None):
        tgt = self.tgt_embeddings(tgt)
        tgt = self.pos_encoder(tgt)
        output = self.decoder(tgt.transpose(0,1), enc_encodings, tgt_sqr_mask,
                              tgt_key_padding_mask=tgt_mask, memory_key_padding_mask=src_mask)
        return output

class TransformerModel(nn.Module):
    def __init__(self, src_vocab_len, tgt_vocab_len, tokenizer, model_dim=512, n_heads=8, n_enc_layers=6, 
                n_dec_layers=6, fc_dim=2048, dropout=0.2, activation='relu'):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        self.n_heads = n_heads
        self.tokenizer = tokenizer
        self.tgt_vocab_len = tgt_vocab_len
        self.encoder = Encoder(src_vocab_len, model_dim, fc_dim, n_heads, n_enc_layers, self.tokenizer.src_vocab["[PAD]"], dropout, activation)
        self.decoder = Decoder(tgt_vocab_len, model_dim, fc_dim, n_heads, n_dec_layers, self.tokenizer.tgt_vocab["[PAD]"], dropout, activation)
        self.out = nn.Linear(model_dim, tgt_vocab_len)
        self._reset_parameters()

    def forward(self, src, tgt, device):
        # src, tgt have shape (batch, seq_len)
        assert src.size(0) == tgt.size(0), "The batch size of source and target sentences should be equal."
        src_mask = get_src_mask(src, self.tokenizer.src_vocab["[PAD]"])
        tgt_mask = get_src_mask(tgt, self.tokenizer.tgt_vocab["[PAD]"])
        tgt_sqr_mask = get_tgt_mask(tgt)
        enc_encodings = self.encoder(src, src_mask.to(device)) 
        output = self.decoder(tgt, enc_encodings, tgt_sqr_mask.to(device), 
                    tgt_mask=tgt_mask.to(device), src_mask=src_mask.to(device))
        output = self.out(output)
        return output.transpose(0,1).contiguous().view(-1, output.size(-1))

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p) 

def get_src_mask(src_tensor, src_pad_id):
    mask = src_tensor != src_pad_id   
    return mask

def get_tgt_mask(tgt_tensor):
    seq_len = tgt_tensor.size(-1)
    mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

And here is my training loop:

def train_model(model, optimizer, criterion, scheduler, train_dl, valid_dl, batch_size, epoch, device, checkpt_path, best_model_path, beam_size, max_decoding_time_step):
    eval_loss = float('inf')
    start_epoch = 0
    if os.path.exists(checkpt_path):
        model, optimizer, eval_loss, start_epoch = load_checkpt(model, checkpt_path, device, optimizer)
        print(f"Loading model from checkpoint with start epoch: {start_epoch} and loss: {eval_loss}")

    best_eval_loss = eval_loss
    print("Model training started...")
    for epoch in range(start_epoch, epoch):
        print(f"Epoch {epoch} running...")
        epoch_start_time = time.time()
        epoch_train_loss = 0
        epoch_eval_loss = 0
        model.train()
        bleu_score = 0
        for batch in train_dl:
            src_tensor, tgt_tensor, _, _ = model.tokenizer.encode(batch, device, return_tensor=True)
            src_tensor = src_tensor.transpose(0,1)
            tgt_tensor = tgt_tensor.transpose(0,1)
            trg_input = tgt_tensor[:, :-1]
            targets = tgt_tensor[:, 1:].contiguous().view(-1)
            
            optimizer.zero_grad()
            preds = model(src_tensor, trg_input.to(device), device)

            loss = criterion(preds, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            epoch_train_loss += loss.item() / batch_size
            print(epoch_train_loss)

The epoch train loss is always nan.

If the invalid values are created in the forward pass, you could use e.g. forward hooks to check all intermediate outputs for NaNs and Infs (have a look at this post for an example usage). On the other hand, if you think the backward pass might create invalid gradients, which would then create invalid parameters, you could use torch.autograd.set_detect_anomaly(True) in your code to get more information about the failing layer.
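A minimal sketch of such a hook (check_for_nans is just an illustrative helper, and model stands for your TransformerModel instance):

import torch

def check_for_nans(module, inputs, output):
    # Flag any NaN/Inf values in the module's output tensor(s)
    outputs = output if isinstance(output, tuple) else (output,)
    for out in outputs:
        if torch.is_tensor(out) and not torch.isfinite(out).all():
            print(f"NaN/Inf in output of {module.__class__.__name__}")

# Register the hook on every submodule of the model
for name, module in model.named_modules():
    module.register_forward_hook(check_for_nans)

# and/or let autograd report the operation that produced an invalid gradient
torch.autograd.set_detect_anomaly(True)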

Thank you @ptrblck. My encoder was giving some nan values due to an incorrect mask. I modified my mask and it worked.

Can you elaborate on what your mistake was? Was it the padding mask or the lookahead mask?

In my case, it was because of an incorrect target sentence mask.

The padding mask or the triangular mask you gave for causality? And how did you fix it? What was the problem?

My padding mask was incorrect for some reason. You can try checking for nan values at every step inside your model's encoder and decoder, as in the sketch below. If any nan value shows up, look into the possible causes at that step. In my case, my decoder was producing nan values; I changed the implementation and it worked for me.
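A quick sketch of what I mean (assert_finite is just an illustrative helper you can sprinkle into your forward passes):

import torch

def assert_finite(tensor, stage):
    # Fail fast and name the stage where a NaN/Inf first shows up
    if not torch.isfinite(tensor).all():
        raise RuntimeError(f"NaN/Inf detected {stage}")

# Example usage inside TransformerModel.forward:
#   enc_encodings = self.encoder(src, src_mask.to(device))
#   assert_finite(enc_encodings, "after encoder")
#   output = self.decoder(tgt, enc_encodings, ...)
#   assert_finite(output, "after decoder")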


I found out my problem. My mask was the opposite of what it should have been. PyTorch's Transformer modules expect a key padding mask in which padded positions are True and non-padded tokens are False.
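For reference, a corrected version of the padding-mask helper from the original post could look like this (a sketch based on the question's get_src_mask):

def get_src_mask(src_tensor, src_pad_id):
    # nn.Transformer* key_padding_mask convention: True = padding (ignored by attention),
    # False = real token. The original helper returned the opposite.
    return src_tensor == src_pad_id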


I just want to add that in my case, I had forgotten to add the BOS and EOS tokens to the corpus, which can also result in a nan loss (yes, a silly beginner mistake :( ).
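In case it helps someone, a minimal sketch of what I mean (the [BOS]/[EOS] strings are placeholders for whatever special tokens your tokenizer actually uses):

BOS, EOS = "[BOS]", "[EOS]"

def add_special_tokens(sentences):
    # Wrap every tokenized target sentence with BOS/EOS before numericalizing
    return [[BOS] + tokens + [EOS] for tokens in sentences]

print(add_special_tokens([["hello", "world"]]))
# [['[BOS]', 'hello', 'world', '[EOS]']]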