I am working with a decoder-only transformer for sequence modeling, with a fairly standard architecture and a custom vocabulary. Since a decoder-only model has no cross-attention, I am building it out of the encoder classes in PyTorch (2.0.1) and making it causal with a mask:
decoder_layer = nn.TransformerEncoderLayer(
    d_model, nhead=nhead, dim_feedforward=1024, dropout=dropout,
    activation='gelu', layer_norm_eps=1e-05, batch_first=True, norm_first=True)
self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=n_layers)
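For context, these lines live in the constructor of a larger nn.Module that also holds an embedding layer and an output projection. A stripped-down sketch of that module (the names and hyperparameter values here are placeholders, not my real ones):

import torch
import torch.nn as nn

class DecoderOnlyModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, n_layers=6, dropout=0.1):
        super().__init__()
        self.nhead = nhead
        # Token embedding (positional encoding omitted here for brevity).
        self.embed = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead=nhead, dim_feedforward=1024, dropout=dropout,
            activation='gelu', layer_norm_eps=1e-05, batch_first=True, norm_first=True)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=n_layers)
        # Maps decoder outputs back to vocabulary logits.
        self.out_proj = nn.Linear(d_model, vocab_size)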
I am using batched inputs of differing sequence lengths, so the beginning of each sequence is padded as necessary (i.e. left padding). To handle this, the forward function of my class accepts a padding mask of shape (batch_size, seq_len), where each row begins with some number of True values (the padded positions) followed by False values for the rest. This was originally passed to the src_key_padding_mask argument of self.decoder. Then, to make the decoder causal, the regular mask argument was given the standard causal mask:
attn_mask = torch.arange(seq_len, device=device).reshape(1, -1) > torch.arange(seq_len, device=device).reshape(-1, 1)
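Putting these together, the forward of the module sketched above looked roughly like this (simplified):

    def forward(self, tokens, padding_mask):
        # tokens: (batch_size, seq_len) token ids, left-padded.
        # padding_mask: (batch_size, seq_len) bool, True at padded positions.
        seq_len = tokens.size(1)
        device = tokens.device
        x = self.embed(tokens)
        # The standard boolean causal mask from above: True = "may not attend".
        attn_mask = (torch.arange(seq_len, device=device).reshape(1, -1)
                     > torch.arange(seq_len, device=device).reshape(-1, 1))
        x = self.decoder(x, mask=attn_mask, src_key_padding_mask=padding_mask)
        return self.out_proj(x)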
Unfortunately, when I ran this network, the last output vector of every single sequence came out as NaN. So my first question is this: is this intended behavior, or did I implement something incorrectly? The documentation is not terribly clear about how mask and src_key_padding_mask interact.
After a long time debugging, I tried manually combining the two masks and removing the src_key_padding_mask argument, so that the only mask passed to the decoder was the following:
# Causal part: True above the diagonal, i.e. no attending to future positions.
attn_mask = torch.arange(seq_len, device=device).reshape(1, -1) > torch.arange(seq_len, device=device).reshape(-1, 1)
# OR in the padding mask so that padded key positions are blocked for every query.
attn_mask = attn_mask.reshape(1, *attn_mask.shape) | padding_mask.reshape(batch_size, 1, -1)
# Tile per attention head: final shape (batch_size * nhead, seq_len, seq_len), the 3D form MultiheadAttention accepts.
attn_mask = attn_mask.repeat(1, self.nhead, 1).reshape(batch_size * self.nhead, seq_len, -1)
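With this, the decoder call in the forward sketched above becomes simply:

# No src_key_padding_mask any more; the padding information lives in the 3D per-head mask.
x = self.decoder(x, mask=attn_mask)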
This fixed the output, so that the last vector was no longer NaN. However, several output vectors at the beginning of each sequence became NaN instead. That much makes sense: those query positions are fully masked, so they are given no attention at all, not even from themselves, and the softmax over a row of entirely masked scores produces NaN.
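You can see the mechanism in isolation: a fully masked attention row reduces to a softmax over nothing but -inf scores, which is NaN by construction:

import torch
# exp(-inf) is 0 for every entry, so the normalizing sum is 0 and 0/0 yields NaN.
scores = torch.full((4,), float('-inf'))
print(torch.softmax(scores, dim=0))  # tensor([nan, nan, nan, nan])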
Unfortunately, this created a new problem: nn.CrossEntropyLoss now produced NaNs during the backward pass. I did not expect this, since every NaN output vector corresponds to a position whose target is the "pad" class, which is given to nn.CrossEntropyLoss as ignore_index. So my second question is: do positions with ignored target indices still get put through autograd? Do they affect the gradient output? Is there an easy way to stop this?
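For reference, the loss is set up roughly like this (PAD_IDX, model, tokens, and targets are illustrative names, not my exact code):

import torch.nn as nn

PAD_IDX = 0  # hypothetical index of the "pad" class in my vocabulary
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

logits = model(tokens, padding_mask)                     # (batch_size, seq_len, vocab_size)
loss = criterion(logits.reshape(-1, logits.size(-1)),    # (batch*seq, vocab_size)
                 targets.reshape(-1))                    # (batch*seq,), padded positions = PAD_IDX
loss.backward()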
Finally, I resolved this issue by modifying the mask yet again to be:
attn_mask = torch.arange(seq_len, device=device).reshape(1, -1) > torch.arange(seq_len, device=device).reshape(-1, 1)
attn_mask = attn_mask.reshape(1, *attn_mask.shape) | padding_mask.reshape(batch_size, 1, -1)
# Un-mask the diagonal so that every position, padded or not, can at least attend to itself.
indices = torch.arange(seq_len, device=device)
attn_mask[:, indices, indices] = False
attn_mask = attn_mask.repeat(1, self.nhead, 1).reshape(batch_size * self.nhead, seq_len, -1)
This allows every padded token to attend to itself, which removes the NaN values. Autograd now completes successfully and the model does train, but this feels extremely hacky and unintended, and I still don't know whether the padded positions are affecting the training process. What is the clean, intended way to handle this?