Left-padded transformer input with causal mask

I am training a decoder-style language model using pytorch's TransformerEncoder. Following huggingface and what I've generally found online, I train with right padding and a causal mask, and during inference I use left padding. Training works well, and if I use teacher forcing during evaluation (i.e. the same setup as training), I get what I expect. If I use my inference decoding during evaluation (no teacher forcing) with a batch size of 1, the system also works as expected. However, I quickly run into an issue with left-padded sequences during inference decoding. If I include both the causal mask and an attention mask (over the pad tokens), it results in NaN outputs for all sequences that are padded. I believe this is because the combination creates query positions that can attend to nothing: every key is masked to -inf, so the softmax produces NaN, which then propagates. If I drop the causal mask during inference, I get different output than during training (which makes sense, since it is a different setup than training).
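To illustrate the NaN mechanism, here is a minimal sketch (plain tensor ops, not my actual model code) combining a causal mask with a left-padding key mask:

```python
import torch

L = 4
# Causal mask: position i may attend only to positions <= i
causal = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)
# Left padding: the first two positions are pad tokens (masked as keys)
key_pad = torch.tensor([True, True, False, False])
# Combine: pad columns become -inf for every query row
scores = causal.masked_fill(key_pad, float("-inf"))
attn = scores.softmax(-1)
# Rows 0 and 1 (the pad-token queries) have every key at -inf,
# so the softmax over them is NaN; rows 2 and 3 are finite.
print(attn)
```

Once any row is NaN, the NaN spreads to the rest of the batch through subsequent layers, which matches the behavior I see.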

So my question is: How can I apply both a causal mask (i.e. kwarg: mask) and an attention mask (i.e. kwarg: src_key_padding_mask) when using left padding?

Thanks in advance for any help.

The snippet of code I am using:

def forward(self, input_ids: th.Tensor, attention_mask: th.Tensor):
    # Inputs are input_ids and attention_mask, both th.Tensors
    using_left_pad = th.any(attention_mask[:, 0] == 0)
    if using_left_pad:
        # Adjust position ids to match training, which uses right padding
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        # Right padding: preprocessor falls back to default 0..L-1 positions
        position_ids = None
    # Encodes ids to embeddings, adds positional encodings, and standard normalization
    embeddings, attention_mask = self.preprocessor(input_ids, attention_mask, position_ids=position_ids)
    batch_size, seq_len, d_model = embeddings.shape
    # Create causal mask
    causal_mask = self.generate_square_subsequent_mask(seq_len)
    # The attention mask comes from a huggingface tokenizer (1 = real token),
    # so it is inverted to match pytorch's key-padding convention (True = pad)
    output = self.decoder(embeddings, mask=causal_mask, is_causal=True,
                          src_key_padding_mask=(attention_mask == 0))
    # With right-padded inputs there is no issue; with batch size 1 there is no issue.
    # With left-padded inputs and batch size > 1, the output is all NaNs for padded sequences.
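For context, the direction I have been exploring is to drop `src_key_padding_mask` entirely and merge both constraints into a single per-sample 3-D float mask passed via `mask`, unmasking the diagonal for all-pad query rows so the softmax stays finite (those rows' outputs get discarded anyway). A sketch of that idea, assuming the mask is expanded to `(batch * nhead, L, L)` as `nn.MultiheadAttention` expects — I am not sure this is the intended approach:

```python
import torch

def build_left_pad_mask(attention_mask: torch.Tensor, num_heads: int = 1) -> torch.Tensor:
    # attention_mask: (B, L), 1 = real token, 0 = pad (huggingface convention)
    B, L = attention_mask.shape
    # True = blocked, for both the causal and the padding constraint
    causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    pad_cols = (attention_mask == 0).unsqueeze(1)        # (B, 1, L): pad keys
    blocked = causal.unsqueeze(0) | pad_cols             # (B, L, L)
    # Let fully-blocked query rows (the pad queries under left padding)
    # attend to themselves so the softmax stays finite
    eye = torch.eye(L, dtype=torch.bool)
    fully_blocked = blocked.all(-1, keepdim=True)        # (B, L, 1)
    blocked = blocked & ~(fully_blocked & eye)
    # Convert to the additive float mask pytorch attention expects
    float_mask = torch.zeros(B, L, L).masked_fill(blocked, float("-inf"))
    # Expand per head: (B * num_heads, L, L)
    return float_mask.repeat_interleave(num_heads, dim=0)
```

The idea being to call `self.decoder(embeddings, mask=build_left_pad_mask(attention_mask, self.nhead))` with no `src_key_padding_mask` and no `is_causal=True` — but I would like to know whether there is a standard way to do this.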