Python implementation of scaled_dot_product_attention() yields a different outcome than `aten::scaled_dot_product_attention`

Hi everyone,

I am trying to customize scaled_dot_product_attention(), which is called in multi_head_attention_forward() as part of the MultiheadAttention class. The original scaled_dot_product_attention() is part of torch._C._nn and (in my understanding) comes down to aten::scaled_dot_product_attention in ONNX.
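As far as I can tell, the builtin dispatches to one of several backends (flash attention, memory-efficient attention, or a plain "math" path). If I read the docs correctly, something like the following can pin it to a single backend; this is just a minimal sketch with made-up shapes, assuming a PyTorch version where torch.nn.attention.sdpa_kernel is available:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim), arbitrary shapes
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Force the un-fused "math" path of the builtin for comparison purposes.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)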

When I copy the source code of the MultiheadAttention class and the multi_head_attention_forward() function, and use the Python code for scaled_dot_product_attention() (provided by PyTorch in a comment block in the docs), the training behavior of my ViT on ImageNet changes drastically. The snippet is shown below:

import math
import torch

# Reference (pure-Python) implementation from the PyTorch documentation of
# torch.nn.functional.scaled_dot_product_attention.
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Lower-triangular mask: each position attends only to itself and the past.
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)  # note: .to() is not in-place; this line has no effect as written

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        # Grouped-query attention: repeat K/V heads to match the number of query heads.
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # Note: dropout is applied with train=True regardless of the module's training mode.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
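As a sanity check I also wanted to compare the snippet against the builtin directly, with dropout disabled so the outputs are deterministic. Again just a sketch with made-up, ViT-like shapes; I would expect the two to agree up to float rounding error, but please correct me if that expectation is wrong:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 8, 197, 64)  # (batch, heads, tokens, head_dim), ViT-like shapes
k = torch.randn(2, 8, 197, 64)
v = torch.randn(2, 8, 197, 64)

ref = scaled_dot_product_attention(q, k, v, dropout_p=0.0)      # the Python snippet above
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)  # the builtin
print((ref - fused).abs().max())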

However, when I run my training code for the ViT with the only difference being whether I call nn.MultiheadAttention() or the copied class that uses the snippet above for scaled_dot_product_attention(), the training behavior changes completely. See the screenshot below:

I want to emphasize that I have not done any customization yet and simply copied the commented source code above.

Am I correct to assume that this difference in training behavior can be attributed to using the Python implementation, and that if I want to maintain good performance I have to use the optimized implementations? Or is the Python implementation supposed to give the same results, just slower?

Following up on that: to customize the attention mechanism, do I have to clone PyTorch, change the .cpp files as discussed here, and build PyTorch from source?

I do not have any experience with this, so if there is a PyTorch-based workaround I would love to hear about it. I looked at jit.trace, but from what I could find it yields a static computational graph, which is fine for inference but not ideal for training.
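One thing I did come across is torch.compile, which (if I understand correctly) does not require building from source and is supposed to work for training as well. Would wrapping the custom Python attention like this be a reasonable route? A minimal sketch, assuming PyTorch 2.x:

import torch

# Compile the copied Python attention function from above; unlike jit.trace this
# should not freeze the computation into a purely static inference graph.
compiled_sdpa = torch.compile(scaled_dot_product_attention)

q = torch.randn(2, 8, 197, 64)
k = torch.randn(2, 8, 197, 64)
v = torch.randn(2, 8, 197, 64)
out = compiled_sdpa(q, k, v, dropout_p=0.0)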

Thank you for any help in advance!

torch.nn.functional.scaled_dot_product_attention is NOT equivalent to the reference definition in the documentation. That's all I know.