Hi,
I am training a decoder-style language model using PyTorch's TransformerEncoder. Following Hugging Face and what I've generally found online, I train with right padding and a causal mask, and during inference I use left padding. Training works well, and if I use teacher forcing during evaluation (i.e. the same setup as training), I get what I expect. Inference decoding during evaluation (no teacher forcing) with a batch size of 1 also works as expected. However, I quickly run into an issue with left-padded sequences during inference decoding. If I include both the causal mask and an attention mask (over the pad tokens), the result is NaN for all sequences that are padded. I think this is because the combination creates tokens that can't attend to anything: every attention score in that row is -inf, the softmax then produces NaN, and the NaN propagates. If I do not include the causal mask during inference, then I get a different output than during training (which makes sense, since it is a different setup than training).
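To illustrate the mechanism I suspect, here is a standalone sketch (not my model code): under left padding, a pad token at position 0 may only attend to position 0 through the causal mask, but that position is also blocked by the padding mask, so its whole score row is -inf and the softmax returns NaN.

```python
import torch

# Attention scores for one query over 4 keys, with every key masked out:
# this is what a leading pad token sees under causal mask + padding mask.
scores = torch.full((1, 4), float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights)  # all NaN: exp(-inf) / sum(exp(-inf)) is 0 / 0
```

Once one row is NaN, the attention output and residual stream for that position are NaN, and it spreads from there.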
So my question is: How can I apply both a causal mask (the mask kwarg) and a padding mask (the src_key_padding_mask kwarg) when using left padding?
Thanks in advance for any help.
The snippet of code I am using:
def forward(self, input_ids: th.Tensor, attention_mask: th.Tensor):
    # Inputs: input_ids and attention_mask, both th.Tensors of shape (batch, seq_len)
    using_left_pad = th.any(attention_mask[:, 0] == 0)
    if using_left_pad:
        # Adjust position ids to match training, which uses right padding
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        position_ids = None  # preprocessor falls back to default positions
    # Encodes ids to embeddings, adds positional encodings, and standard normalization
    embeddings, attention_mask = self.preprocessor(input_ids, attention_mask, position_ids=position_ids)
    batch_size, seq_len, d_model = embeddings.shape
    # Create causal mask
    causal_mask = self.generate_square_subsequent_mask(seq_len)
    # The attention mask comes from the huggingface tokenizer (1 = keep), so it is
    # inverted with 1 - to match the pytorch convention (True = ignore)
    output = self.decoder(embeddings, mask=causal_mask, is_causal=True,
                          src_key_padding_mask=(1 - attention_mask).bool())
    # With right-padded inputs there is no issue; with a batch size of 1 there is no issue.
    # With left-padded inputs where batch size > 1, the output is all NaNs on padded sequences.
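For reference, the behaviour reproduces with a bare nn.TransformerEncoder, independent of my model (toy sizes, chosen only for illustration): the sequence with leading pads comes out NaN, while the unpadded one is fine.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True)
enc = nn.TransformerEncoder(layer, num_layers=2)

seq_len = 4
x = torch.randn(2, seq_len, 8)
causal = nn.Transformer.generate_square_subsequent_mask(seq_len)
# Second sequence is left-padded: its first two positions are pads (True = ignore)
pad_mask = torch.tensor([[False, False, False, False],
                         [True,  True,  False, False]])
out = enc(x, mask=causal, src_key_padding_mask=pad_mask)
print(out[0].isnan().any(), out[1].isnan().any())
```

The leading pad positions attend to nothing, go NaN, and the NaN then contaminates the rest of that sequence in later layers (a masked-out key still contributes 0 * NaN = NaN to the weighted sum).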