Transformer Decoder Peeking into the Future Despite Subsequent Mask?

Hey all,

I’ve been playing around with the Transformer model in PyTorch and bumped into something a bit weird.

I set up a transformer with the standard subsequent (causal) mask on the decoder so that it can't peek at future positions when generating the output. But when I feed the whole target sequence to the model and compare the result with what I get from feeding only the first 10 positions of that sequence, the outputs for those first 10 positions don't match.

Here’s what I’m doing in code:

import torch
import torch.nn as nn

# Tiny model so the effect is easy to inspect; dropout is disabled so the
# comparison below is deterministic.
transformer_model = nn.Transformer(
    nhead=1,
    d_model=4,
    num_encoder_layers=1,
    num_decoder_layers=1,
    batch_first=True,
    dropout=0.0,
)

# Random dummy inputs: (batch, seq_len, d_model).
src = torch.rand((1, 10, 4))
tgt = torch.rand((1, 20, 4))

transformer_model.eval()

# Full target sequence (20 positions) with a causal mask.
out = transformer_model(
    src,
    tgt,
    tgt_mask=transformer_model.generate_square_subsequent_mask(tgt.size(1)),
)

# Only the first 10 target positions, with a matching 10x10 causal mask.
out_limited = transformer_model(
    src,
    tgt[:, :10, :],
    tgt_mask=transformer_model.generate_square_subsequent_mask(tgt[:, :10, :].size(1)),
)

# If the mask is doing its job, I'd expect this to be (essentially) zero.
print((out[:, :10] - out_limited).mean())

From my understanding of how Transformers and the subsequent mask work, I'd expect the first 10 positions of out to be identical to out_limited: with the causal mask, each decoder position should only attend to itself and earlier positions, and the encoder input is the same in both calls. But it looks like the future tokens in tgt are somehow influencing the output at the earlier positions, even with the mask in place.
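For what it's worth, here's what the mask I'm passing in looks like (printing a small length-5 one just to show its structure), so I don't think the mask itself is malformed - it should be 0 on and below the diagonal and -inf above it:

mask = transformer_model.generate_square_subsequent_mask(5)
print(mask)
# Prints something like:
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0.,   0., -inf, -inf, -inf],
#         [0.,   0.,   0., -inf, -inf],
#         [0.,   0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.,   0.]])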

I've got a hunch that this could be because of where the mask is applied in the decoder - after the attention scores are computed, but before they're combined into the output. That might let the future tokens "leak" into the earlier positions through the softmax normalization in the attention function.
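To spell out what I mean, here's the textbook formulation of masked scaled dot-product attention that I have in mind - just a simplified sketch of my mental model (masked_attention is my own illustrative helper, not PyTorch's actual code). The mask is added to the raw scores before the softmax, so a -inf entry should end up with exactly zero weight, and the softmax step is where I'm imagining any leakage would have to happen:

import math
import torch

def masked_attention(q, k, v, mask):
    # Raw attention scores, shape (..., tgt_len, src_len).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # The causal mask (-inf above the diagonal) is added to the scores
    # *before* the softmax; exp(-inf) = 0, so in this formulation the
    # masked (future) positions get zero weight and shouldn't contribute
    # anything to the weighted sum.
    weights = torch.softmax(scores + mask, dim=-1)
    return weights @ v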

Just wondering if anyone else has noticed this or if I’m missing something here? Any ideas for a workaround?

Cheers!