Scaled_dot_product_attention is not numerically stable

I am unable to get the scaled_dot_product_attention function to be numerically stable, i.e. to return the same values as a manual implementation of attention.
The snippet below (taken from a Colab notebook) shows the problem.

Once I calculate attention manually, and once I use scaled_dot_product_attention.

Even with

  with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):

the values don’t match.

A little floating-point error is to be expected, but as you can see, I get a maximum absolute difference of about 3e-07, which is larger than the standard 1e-07.
Is this much deviation expected?

Mismatched elements: 76893 / 172032 (44.7%)
Max absolute difference: 2.9802322e-07
Max relative difference: 5.89578e-07

import torch
import numpy as np

# should be unnecessary, not running on a GPU
torch.backends.cudnn.deterministic = True


# asserts that two tensors match within the given relative tolerance (raises on mismatch)
def assert_equal(output1, output2):
  a1, a2 = output1.detach().numpy(), output2.detach().numpy()
  np.testing.assert_allclose(a1, a2, rtol=1e-07)

# typical values for a Vision Transformer
head_dim = 64
no_heads = 12
no_input_vectors = 14 * 16

# random input of shape (batch size, number of input vectors, embedding dimension)
input = torch.rand(1, no_input_vectors, head_dim * no_heads)

# First, calculate attention manually

x_manual = torch.clone(input)
x_manual = x_manual.reshape(1, no_input_vectors, no_heads, head_dim).permute(0, 2, 1, 3)
# x_manual now has shape (batch, heads, tokens, head_dim) = (1, 12, 224, 64)
q, k, v = x_manual, x_manual, x_manual
k_t = k.transpose(-2, -1)
s_dot_product = (q @ k_t) * (head_dim ** -0.5)  # scaled attention scores
attn = s_dot_product.softmax(dim=-1)
x_manual = attn @ v



## Now do it with the scaled_dot_product_attention function

x_auto = torch.clone(input)
# Disable all "improvements"
# The values also don't match with enable_math=True
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
  x_auto = x_auto.reshape(1, no_input_vectors, no_heads, head_dim).permute(0, 2, 1, 3)
  x_auto = torch.nn.functional.scaled_dot_product_attention(x_auto, x_auto, x_auto)


assert_equal(x_manual, x_auto)  # raises AssertionError if the tensors don't match within rtol
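
As a side note (this is not part of my original test, just a sketch that reuses x_manual and x_auto from the snippet above), a dtype-aware comparison with torch.testing.assert_close should pass here, since its default tolerances for float32 (around rtol=1.3e-06, atol=1e-05, if I read the docs correctly) are looser than a fixed 1e-07:

# sketch: compare with dtype-aware default tolerances instead of rtol=1e-07
torch.testing.assert_close(x_manual, x_auto)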

Yes, your error is still in the expected range. Could you explain where the hard limit of 1e-7 comes from?

Thanks for the reply, good question.

np.testing.assert_allclose uses 1e-07 as its default rtol, so I assumed this is the accuracy we want to achieve if two calculations are supposed to be identical. But this might be wrong.
I was using this value to test whether a custom transformer implementation is identical to a timm.VisionTransformer, but timm has now switched to F.scaled_dot_product_attention, so my test is failing.
I guess I can set the value to 1e-06.
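
Concretely, I mean something like this (just a sketch, reusing the tensors from my snippet above):

# relaxed check: rtol=1e-06 instead of numpy's default 1e-07
np.testing.assert_allclose(x_manual.detach().numpy(), x_auto.detach().numpy(), rtol=1e-06)
# this should pass, since the max relative difference reported above was about 5.9e-07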

Thanks for the background. Yes, I think the default values make sense, but I would not see them as a hard limit. For example, even accumulating the 100 * 100 elements of a random tensor on the CPU in two different orders already yields a larger error:

import torch

x = torch.randn(100, 100)
s1 = x.sum()          # sum all elements at once
s2 = x.sum(0).sum(0)  # sum per column first, then across the column sums

print((s1 - s2).abs().max())
# tensor(4.3869e-05)
print((s1 - s2).abs() / s2.abs())
# tensor(1.4508e-06)
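
As a quick extension of the snippet above (my own sketch, the exact numbers will vary per run), repeating the same comparison in float64 shrinks the difference by several orders of magnitude, which shows it is just float32 accumulation order:

import torch

x = torch.randn(100, 100)
d32 = (x.sum() - x.sum(0).sum(0)).abs()
d64 = (x.double().sum() - x.double().sum(0).sum(0)).abs()
print(d32, d64)
# d64 is typically several orders of magnitude smaller than d32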