What has been going on with the masks and 'nan' in transformers

Hi @ptrblck

import torch
seq_len = 10

# key-padding mask: True = this timestep is padding (last 4 positions here)
pad_mask = torch.zeros(seq_len, dtype=torch.bool)
pad_mask[6:] = True

# causal (subsequent) mask: -inf above the diagonal, 0.0 elsewhere
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
causal_mask = causal_mask.float().masked_fill(causal_mask == 1, float('-inf'))

Basically causal_mask is the subsequent mask.
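For a small seq_len it looks like this (0.0 where attention is allowed, -inf above the diagonal):

>>> m = torch.triu(torch.ones(4, 4), diagonal=1)
>>> m.masked_fill(m == 1, float('-inf'))
tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])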

Now I have 2 issues:

  1. ‘nan’ inference issue. I do not know why it is happening, as my data is properly normalized and scaled (between 0.0 and 1.0) and the learning rate also seems to be okay. My decoder is just a simple linear layer, though. Also, this model has been working perfectly on a smaller subset of the data (approx. 2M rows, but the real data is more than 500x bigger). I know that multiplying 0.0 by float(‘inf’) also yields ‘nan’ (see the quick check after this list); can it be because of that?
  2. UserWarning: Support for mismatched src_key_padding_mask and mask is deprecated. Use same type for both instead.
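The quick check from point 1 (both plain Python and a torch scalar give nan):

>>> 0.0 * float('inf')
nan
>>> torch.tensor(0.0) * float('-inf')
tensor(nan)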
import torch.nn as nn

# NOTE: PositionalEncoding is defined elsewhere and not shown here.
class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, dropout, seq_len):
        super().__init__()
        self.seq_len = seq_len
        
        # Initializing objects
        self.encoder = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout, max_len=self.seq_len)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.decoder = nn.Linear(d_model, 1)

    def forward(self, x, xlens=None, causal_mask=None, pad_mask=None):
        """
       For this forum purpose assume that the mask is comming in correct form but created from the code above. 
        x: Tensor, shape [batch_size, seq_len, number_of_features]
        """
        
        x = self.encoder(x)        
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, mask=causal_mask, src_key_padding_mask=pad_mask, is_causal=True)
        x = self.decoder(x)
        x = x.squeeze(-1)
        return x

Should I change my decoder to a few transformer decoder layers? How do I fix these issues? My data is quite big, so training is quite slow. I am using DDP, so can this ‘nan’ issue occur because of that?

If you have an attention mask that masks out an entire row, that is one way NaNs can arise in scaled dot product attention.

Can you please explain a bit more what you mean, or give an example?

I have both masks as mentioned above, where causal_mask has shape [S, S] and pad_mask has shape [S]. Also, this issue does not occur with ~2M data rows, and with ~500M rows it is okay until about 10 epochs, but occurs after that.

In the attention’s softmax, if the logits to the softmax are all zeroed out, then a NaN output would be produced.

That makes sense.

Question: any idea how to fix the second point, i.e. the UserWarning?

Hi Soulitzer!

As an aside, that’s not actually true:

>>> torch.zeros (5).softmax (dim = 0)
tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
>>> torch.zeros (5).log_softmax (dim = 0)
tensor([-1.6094, -1.6094, -1.6094, -1.6094, -1.6094])
>>> torch.zeros (0).softmax (dim = 0)
tensor([])
>>> torch.zeros (0).log_softmax (dim = 0)
tensor([])

(To Prateek: I don’t have anything useful to say about the issue in your
original post.)

Best.

K Frank

Good catch! What I meant was zeroed out post-exponentiation:

>>> torch.full((5,), float("-inf")).softmax(dim=0)
tensor([nan, nan, nan, nan, nan])
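
And at the level of scaled dot product attention itself, a query row whose keys are all masked out hits exactly that case. A toy sketch (exact behavior can depend on the PyTorch version and SDPA backend, but with the default math path that row comes out as NaN):

import torch
import torch.nn.functional as F

# toy single-head attention: batch 1, 4 positions, embedding dim 8
q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)

# boolean attn_mask: True = may attend, False = masked out.
# Masking out every key for the last query leaves that softmax row
# with only -inf logits, so its output row becomes NaN.
attn_mask = torch.ones(4, 4, dtype=torch.bool)
attn_mask[3, :] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out[0, 3])  # tensor([nan, nan, nan, nan, nan, nan, nan, nan])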

Thanks @soulitzer @KFrank

I am trying to learn and understand more about the reasons behind this issue. I am using DDP to train the model over 3-4 nodes with 1-2 A100 GPUs each, but since each epoch takes about an hour to train, I am trying to avoid wasting resources and time while tuning these things.

So I changed a few things and am currently running a test for a few epochs:

  1. Changed the total number of parameters the model learns (through model_dim, num_heads, num_layers, etc.). Any suggestions on these hyperparameters, or even a ballpark, are welcome.
  2. In an article, I read that the ‘nan’ issue can be due to exploding gradients, so I reduced the learning rate by 10x. If this only helps up to a certain number of epochs, I will reduce it further or use gradient clipping (which I honestly want to avoid).
  3. I changed float(“-inf”) to -1e9 (see the quick check after this list).
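Quick check of point 3: with a large negative constant instead of -inf, even a fully masked softmax row gives a valid uniform distribution instead of NaN:

>>> torch.full((5,), -1e9).softmax(dim=0)
tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])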

Meanwhile, while this is training, I am trying to understand the masks better:
I have always used the pad_mask and the subsequent/causal/look-ahead mask separately. The causal/attn masks have been f16/f32 and the pad_mask has been bool, but I get a warning thrown about making them the same type. Could you give an example?
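
For reference, this is the direction I am thinking of for making the dtypes match (just a sketch with made-up sizes; I have not confirmed it is the recommended fix): keep both masks boolean, where True means "not allowed to attend" for the attention mask and "padding" for the key-padding mask.

import torch
import torch.nn as nn

batch_size, seq_len, d_model = 2, 10, 16

enc_layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
x = torch.randn(batch_size, seq_len, d_model)

# both masks boolean, so the dtypes match
# causal mask: True = position is not allowed to attend (upper triangle)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
# key-padding mask: True = this timestep is padding
pad_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
pad_mask[:, 6:] = True

out = encoder(x, mask=causal_mask, src_key_padding_mask=pad_mask, is_causal=True)
print(out.shape)  # torch.Size([2, 10, 16])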