I am unable to get the scaled_dot_product_attention function to return the same values as a manual implementation of attention.
The Colab snippet below shows the problem: I compute attention once manually and once with scaled_dot_product_attention.
Even with all SDPA kernels disabled via
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
the values don't match.
A little floating-point error is to be expected, but as you can see below, I get a maximum absolute difference of about 3e-07, which is larger than the default rtol of 1e-07 used by np.testing.assert_allclose.
Is this much deviation expected?
Mismatched elements: 76893 / 172032 (44.7%)
Max absolute difference: 2.9802322e-07
Max relative difference: 5.89578e-07
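For reference: judging by the reported max relative difference (about 5.9e-07), a comparison with a slightly looser tolerance should already pass. This is just my own variant of the check, reusing the x_manual and x_auto tensors defined in the snippet below:
# same comparison as assert_equal below, but with rtol relaxed from
# numpy's default 1e-07 to 1e-06 (the reported max relative difference
# of ~5.9e-07 is below that)
np.testing.assert_allclose(
    x_manual.detach().numpy(),
    x_auto.detach().numpy(),
    rtol=1e-06,
)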
import torch
import numpy as np
# should be unnecessary, not running on GPU
torch.backends.cudnn.deterministic = True
# raises an AssertionError if the two tensors are not numerically close
def assert_equal(output1, output2):
a1, a2 = output1.detach().numpy(), output2.detach().numpy()
np.testing.assert_allclose(a1, a2, rtol=1e-07)
# regular values for Vision Transformer
head_dim = 64
no_heads = 12
no_input_vectors = 14*16
# random input of shape (batch size, number of input vectors, embedding dimension)
input = torch.rand(1, no_input_vectors, head_dim*no_heads)
# First, calculate attention manually
x_manual = torch.clone(input)
x_manual = x_manual.reshape(1, no_input_vectors, no_heads, head_dim).permute(0, 2, 1, 3)
q, k, v = x_manual, x_manual, x_manual
k_t = k.transpose(-2, -1)
s_dot_product = (q @ k_t) * (head_dim**-0.5)
attn = s_dot_product.softmax(dim=-1)
x_manual = attn @ v
## Now do the same with the scaled_dot_product_attention function
x_auto = torch.clone(input)
# Disable all "improvements"
# Also doesn't work with enable_math=True
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
    x_auto = x_auto.reshape(1, no_input_vectors, no_heads, head_dim).permute(0, 2, 1, 3)
    x_auto = torch.nn.functional.scaled_dot_product_attention(x_auto, x_auto, x_auto)
assert_equal(x_manual, x_auto)
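To judge whether the ~3e-07 gap is just float32 rounding, here is a rough sketch of an extra check (my own addition, not part of the colab; the names x64 and ref are mine) that compares both float32 results against a float64 reference computed the same way as the manual path above:
# float64 reference of the same self-attention computation
x64 = torch.clone(input).double()
x64 = x64.reshape(1, no_input_vectors, no_heads, head_dim).permute(0, 2, 1, 3)
ref = ((x64 @ x64.transpose(-2, -1)) * (head_dim**-0.5)).softmax(dim=-1) @ x64
# distance of each float32 result from the float64 reference
print((x_manual.double() - ref).abs().max())
print((x_auto.double() - ref).abs().max())
If both results end up roughly equally far from the float64 reference, the mismatch is presumably just a difference in rounding / accumulation order between the two float32 code paths.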