Scaled_dot_product_attn not giving exact same results as normal attention

Hello, I’m trying to replace my QKV attention function with torch.nn.functional.scaled_dot_product_attention to benefit from memory-efficient attention. However, the result is very slightly different from my own implementation, which mostly follows the standard attention calculation.
-Training-wise, I notice slower convergence and often unstable training when I use SDP, so I’m trying to understand where these differences come from
-I checked which method is being invoked and it seems the C++ implementation is the only one that works without erroring out.
-This occurs with or without mixed precision
-I set torch.backends.cudnn.deterministic = True and it did not change the output

My variables are query, key, and value, all of shape (B, L, E) = (1, 32768, 192)

Here is the function I’m using

Here is my older implementation
attn_weight = torch.softmax(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)), dim=-1)
x2 = attn_weight @ value

I expect x to be exactly equal to x2 but there are minor differences in some elements of the output tensor.
Am I understanding the SDP function correctly? I would have no issue if the outputs were very slightly off, but it seems that this affects training significantly.
Thank you!

You cannot expect bitwise-identical results with floating-point dtypes unless you force deterministic algorithms, because floating-point precision is limited. How large is the relative error between your manual output and the SDP one?
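As a small illustration of why limited precision matters: floating-point addition is not associative, so two implementations that accumulate the same terms in a different order can disagree in the last bits. The snippet below is only a sketch in plain Python (the variable names are made up for the example) and says nothing about which kernels SDP actually runs.

```python
import math

# Same three numbers, two different groupings of the additions.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)

print(left == right)                            # False: not bitwise identical
print(abs(left - right))                        # tiny, on the order of 1e-16
print(math.isclose(left, right, rel_tol=1e-12)) # True: equal within tolerance
```

The same effect, repeated across the many additions inside a matmul and a softmax, is enough to make two mathematically equivalent attention implementations differ slightly.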

Thank you for the response. I half suspected that the differences were due to floating-point cutoffs, but it’s good to confirm.
If I take (x - x2).sum(), I get a difference of about 0.03 for tensors of size (1, 32768, 192).
The main thing that confused me was the different training dynamics when I substituted in the SDP function. I had assumed that this was due to the attention being very slightly inexact, but it seems that’s not the case?

I don’t know if the errors are expected or not, since you are summing the signed differences between the tensors, so positive and negative errors can cancel each other out.
What does print((x - x2).abs().max()) and print(x.abs().max(), x2.abs().max()) return?
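A rough sketch of that check, for reference. It assumes small random CPU tensors as stand-ins for the real inputs (the thread uses (1, 32768, 192)), and the names max_abs_err and rel_err are just for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Small stand-in inputs of shape (B, L, E); not the real data from the thread.
query, key, value = (torch.randn(1, 256, 64) for _ in range(3))

# Fused SDP output vs. the manual implementation from the first post.
x = F.scaled_dot_product_attention(query, key, value)
attn_weight = torch.softmax(query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)), dim=-1)
x2 = attn_weight @ value

# Largest elementwise error, and the same error relative to the output magnitude.
max_abs_err = (x - x2).abs().max().item()
rel_err = max_abs_err / max(x.abs().max().item(), x2.abs().max().item())

print(max_abs_err, rel_err)
print(x.abs().max().item(), x2.abs().max().item())
print(torch.allclose(x, x2, atol=1e-5))
```

If the maximum absolute error is tiny compared with the output magnitude (for float32, a relative error somewhere around 1e-6 or smaller would be typical), the mismatch is most likely just rounding, and the training differences probably come from somewhere else.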