Large output differences when using torch.no_grad() with TransformerEncoder (FlashAttention?)

I’m noticing large differences (elements differing by more than 1) between the outputs of a forward pass through a TransformerEncoder run with and without torch.no_grad(). I believe this happens because one call uses FlashAttention (via the native fast path) and the other does not.

Is this just a consequence of accumulated floating-point differences? I don’t observe the behavior when no mask is provided (max abs difference is ~1e-6). Torch version: 2.6.0.

import torch
import torch.nn as nn

torch.manual_seed(10)

SEQ_LENGTH = 30
EMBEDDING_DIM = 128
BATCH_SIZE = 16

# requires_grad=True keeps the first call on the regular (autograd) code path
fake_tokens = torch.ones(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM, requires_grad=True)
# True marks a padded key position; here the trailing quarter of each sequence is padded
fake_mask = torch.ones((BATCH_SIZE, SEQ_LENGTH), dtype=torch.bool)
fake_mask[:, :-SEQ_LENGTH // 4] = False

encoder_layer = nn.TransformerEncoderLayer(
    d_model=EMBEDDING_DIM, nhead=4, batch_first=True, dropout=0.0
)
encoder = nn.TransformerEncoder(encoder_layer, 8).eval()

fake_out = encoder(fake_tokens, src_key_padding_mask=fake_mask, is_causal=False)
with torch.no_grad():
    fake_out_no_grad = encoder(fake_tokens, src_key_padding_mask=fake_mask, is_causal=False)

print(torch.abs(fake_out - fake_out_no_grad).max())

Output:

tensor(2.8289, grad_fn=<MaxBackward1>)