torch.compile: _softmax in MultiheadAttention does not return the same value as eager

Hello,

I’m trying to enable torch.compile in my program, but I’m encountering reproducibility issues with the following script on PyTorch 2.7.1:

import os
from functools import partial

import torch
from torch._inductor.compiler_bisector import CompilerBisector
from torch.nn import MultiheadAttention
from torch.testing import assert_close

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required for deterministic cuBLAS behavior with use_deterministic_algorithms(True)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=False)

# Restrict SDPA to the math backend so eager and compiled use the same kernel.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_math_sdp(False)

strict_assert_close = partial(assert_close, atol=0, rtol=0)  # bitwise comparison
device = "cuda"

def fn():
    torch._dynamo.reset()  # clear compilation caches between bisection runs

    self_attn_builder = partial(MultiheadAttention, embed_dim=4, num_heads=1, dropout=0.1, bias=True, batch_first=True)

    torch.manual_seed(0)
    tf_raw = self_attn_builder().to(device)

    torch.manual_seed(0)
    tf_compiled = torch.compile(self_attn_builder().to(device), fullgraph=True)

    torch.manual_seed(0)
    transformer_input = torch.randn(128, 50, 4).to(device)

    torch.manual_seed(0)
    output_1 = tf_raw(transformer_input, transformer_input, transformer_input, attn_mask=None, key_padding_mask=None)

    torch.manual_seed(0)
    output_compiled = tf_compiled(transformer_input, transformer_input, transformer_input, attn_mask=None, key_padding_mask=None)

    try:
        strict_assert_close(output_compiled, output_1)
        return True
    except AssertionError:
        return False

CompilerBisector.do_bisect(fn)

I can see that the output of aten._softmax differs between compiled and eager modes, which breaks my reproducibility tests.
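
A stripped-down check to isolate the softmax alone would look like this (a sketch; same input shape as my script above):

import torch
from torch.testing import assert_close

device = "cuda"

def softmax_fn(t):
    return torch.softmax(t, dim=-1)

softmax_compiled = torch.compile(softmax_fn, fullgraph=True)

x = torch.randn(128, 50, 4).to(device)

# Strict bitwise comparison, mirroring strict_assert_close above.
assert_close(softmax_compiled(x), softmax_fn(x), atol=0, rtol=0)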

Is this expected behavior, or is it a bug? Are there any options to ensure identical results between torch.compile and eager execution?

Thanks for your help!

How large is the absolute and relative error?

No, there is no guarantee different algorithms will return bitwise identical outputs.
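
For two valid but different kernels, the realistic expectation is agreement within the default dtype-dependent tolerances of assert_close (rtol=1.3e-6, atol=1e-5 for float32), not bitwise equality with atol=0, rtol=0:

from torch.testing import assert_close

# Compare against the default float32 tolerances instead of bitwise equality.
assert_close(output_compiled, output_1)

# Or make the tolerances explicit:
assert_close(output_compiled, output_1, rtol=1.3e-6, atol=1e-5)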

Yes, but the error is not small:

AssertionError: Tensor-likes are not equal!

Mismatched elements: 25600 / 25600 (100.0%)
Greatest absolute difference: 0.4337801933288574 at index (3, 1, 0)
Greatest relative difference: 2057.0537109375 at index (97, 14, 0)

The failure occurred for item [0]

In training, it completely changes my results.

I can reproduce this larger error, but I get only the expected small numerical mismatches after disabling dropout (see the sketch below). I don’t know if random operations are guaranteed to return the same outputs under torch.compile, as I would guess the dropout mask sampling differs between eager and compiled mode.
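
For reference, a minimal sketch of the dropout-free comparison, based on your script (with dropout=0.0 no RNG runs in the forward pass, so only kernel-level differences remain):

from functools import partial

import torch
from torch.nn import MultiheadAttention
from torch.testing import assert_close

device = "cuda"
builder = partial(MultiheadAttention, embed_dim=4, num_heads=1,
                  dropout=0.0,  # disable dropout so no RNG runs in the forward pass
                  bias=True, batch_first=True)

torch.manual_seed(0)
raw = builder().to(device)
torch.manual_seed(0)
compiled = torch.compile(builder().to(device), fullgraph=True)

x = torch.randn(128, 50, 4).to(device)
out_raw = raw(x, x, x)
out_compiled = compiled(x, x, x)

# Passes within the default tolerances; bitwise equality is still not guaranteed.
assert_close(out_compiled, out_raw)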

Ah, thank you! I assumed that torch.manual_seed(0) would make dropout reproducible under torch.compile, and CompilerBisector.do_bisect led me down the wrong path.
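
For anyone hitting the same thing: Inductor has a debugging flag that makes compiled code fall back to the eager RNG for random ops, which should make dropout masks match between eager and compiled runs (I haven’t verified it for every PyTorch version):

import torch._inductor.config

# Debugging option: route random ops (e.g. dropout mask sampling) through the
# eager RNG inside compiled regions, trading performance for reproducibility.
torch._inductor.config.fallback_random = True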