# Why does F.scaled_dot_product_attention output differ from normal attention in this case?

I have created a minimal script to reproduce the issue.

I am trying to re-implement the normal attention function. Without a math background, I can't get `F.scaled_dot_product_attention` to produce nearly the same output as `normal_attention`; after a whole day of debugging, I need some help.

```python
#!/usr/bin/env python3

import torch
import torch.nn.functional as F


def normal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """
    #### Normal Attention

    :param q: query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    :param k: key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    :param v: value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    """
    # Split them into heads of shape [batch_size, seq_len, n_heads, d_head]
    q = q.view(*q.shape[:2], 8, -1)
    k = k.view(*k.shape[:2], 8, -1)
    v = v.view(*v.shape[:2], 8, -1)

    # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
    attn = torch.einsum('bihd,bjhd->bhij', q, k)

    # Compute softmax
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
    # (one batch half at a time, to lower peak memory)
    half = attn.shape[0] // 2
    attn[half:] = attn[half:].softmax(dim=-1)
    attn[:half] = attn[:half].softmax(dim=-1)

    # Compute attention output
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
    out = torch.einsum('bhij,bjhd->bihd', attn, v)
    # Reshape to [batch_size, seq, n_heads * d_head]
    out = out.reshape(*out.shape[:2], -1)
    # (the original module maps to [batch_size, seq, d_model] with a linear layer here)
    return out


query = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
key = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
value = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")

result_normal = normal_attention(query, key, value)

# F.scaled_dot_product_attention expects [batch, heads, seq, d_head]
query = query.view(32, 128, 8, 64).transpose(1, 2)
key = key.view(32, 128, 8, 64).transpose(1, 2)
value = value.view(32, 128, 8, 64).transpose(1, 2)

with torch.backends.cuda.sdp_kernel(enable_math=False):
    result_torch = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

result_torch = result_torch.transpose(1, 2).reshape(32, 128, 512)

print(result_torch.shape)
print(result_normal.shape)
print((result_normal - result_torch).abs().max())
```



`result_torch` and `result_normal` are not even close to each other.
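In fp16 I wouldn't expect bit-exact equality between the two paths anyway, so for reference, this is how the difference could be checked with tolerances instead of eyeballing the max error (a minimal sketch using `torch.testing.assert_close`; the tolerance values are my own guess at what is reasonable for fp16):

```python
import torch

# Loose fp16 tolerances; raises with a summary of the mismatches if exceeded
torch.testing.assert_close(result_torch, result_normal, rtol=1e-3, atol=1e-3)
```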

The `normal_attention` function above is from the diffusers 0.8.0 repository on GitHub, with constant numbers substituted for its variables; the `F.scaled_dot_product_attention` usage is my best guess, since I don't understand the math.

Additional explanation of the `normal_attention` hack: it computes the softmax in two batch halves to use less VRAM, so the model fits on small GPUs.
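To show why that split is safe: softmax over the last dim treats every row independently, so computing it half a batch at a time gives the same numbers while lowering the peak memory of the softmax step. A minimal sketch (the shapes and names here are mine, not from diffusers):

```python
import torch

attn = torch.rand(32, 8, 128, 128)  # [batch_size, n_heads, seq, seq] scores

# Softmax over the whole batch at once
full = attn.softmax(dim=-1)

# The VRAM hack: softmax one batch half at a time, in place
hacked = attn.clone()
half = hacked.shape[0] // 2
hacked[half:] = hacked[half:].softmax(dim=-1)
hacked[:half] = hacked[:half].softmax(dim=-1)

# Same numbers: softmax(dim=-1) never mixes rows, let alone batch entries
print(torch.allclose(full, hacked))  # True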

`normal_attention` is at most ~0.06 away from the standard attention implementation. The xformers library has a `scaled_dot_product_attention` as well; it produces near-exact output (~0.0005) against the original attention, and is likewise ~0.06 off from this `normal_attention` hack.

However, when given 4-dim inputs, the xformers `scaled_dot_product_attention` is near-identical to `F.scaled_dot_product_attention`.
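For reference, this is roughly how the two 4-dim paths line up (a sketch assuming `xformers.ops.memory_efficient_attention` as the xformers entry point; it takes `[batch, seq, heads, d_head]`, while `F.scaled_dot_product_attention` takes `[batch, heads, seq, d_head]`):

```python
import torch
import torch.nn.functional as F
import xformers.ops as xops

q = torch.rand(32, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.rand(32, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.rand(32, 128, 8, 64, dtype=torch.float16, device="cuda")

# xformers wants [batch, seq, heads, d_head]
out_xf = xops.memory_efficient_attention(q, k, v)

# torch wants [batch, heads, seq, d_head], so transpose around the call
out_pt = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# Both apply the 1/sqrt(d_head) scaling internally, so they should agree
print((out_xf - out_pt).abs().max())
```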

So the problem is how to make `F.scaled_dot_product_attention` take 3-dim inputs and generate reasonably near-exact output.
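What eventually made sense to me: `F.scaled_dot_product_attention` has no notion of "heads packed into the last dim", so the 3-dim `[batch_size, seq, d_attn]` layout has to be split and transposed around the call. Note also that it divides the $Q K^\top$ scores by $\sqrt{d_{head}}$ internally, so whatever it is compared against must apply the same scale. A hedged sketch of such a wrapper (`sdpa_3d` and `n_heads` are my names, not a library API):

```python
import torch
import torch.nn.functional as F

def sdpa_3d(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, n_heads: int = 8) -> torch.Tensor:
    """Call F.scaled_dot_product_attention on [batch_size, seq, d_attn] inputs."""
    batch_size, seq, d_attn = q.shape
    d_head = d_attn // n_heads

    # [batch, seq, d_attn] -> [batch, n_heads, seq, d_head], the layout sdpa expects
    def split(t: torch.Tensor) -> torch.Tensor:
        return t.view(t.shape[0], t.shape[1], n_heads, d_head).transpose(1, 2)

    out = F.scaled_dot_product_attention(split(q), split(k), split(v))

    # [batch, n_heads, seq, d_head] -> [batch, seq, d_attn]
    return out.transpose(1, 2).reshape(batch_size, seq, d_attn)
```

With the inputs from the script above, `sdpa_3d(query, key, value)` reproduces `result_torch` exactly; whether it matches `normal_attention` then depends on `normal_attention` applying the same $1/\sqrt{d_{head}}$ scale.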

I'll put the standard attention without the hack here, in case the ~0.06 error seems annoying:

```python
def _attention(self, query, key, value):
    attention_scores = torch.baddbmm(
        torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=self.scale,
    )
    attention_probs = attention_scores.softmax(dim=-1)

    # compute attention output
    hidden_states = torch.bmm(attention_probs, value)

    # reshape hidden_states
    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
    return hidden_states
```

The inputs get reshaped before calling it:

```python
q2 = reshape_heads_to_batch_dim(query)
```

I don't know why `reshape_heads_to_batch_dim` needs to be added for `normal_attention`, or what it does, but I am glad I finally made it work at last.
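For anyone else landing here: in diffusers 0.8.0, `reshape_heads_to_batch_dim` folds the head axis into the batch axis, so that `_attention`'s `torch.bmm` multiplies one `[seq, d_head]` matrix per (batch, head) pair. A sketch of the idea (a free-function version of the diffusers method; the `n_heads` default is my assumption for this example):

```python
import torch

def reshape_heads_to_batch_dim(tensor: torch.Tensor, n_heads: int = 8) -> torch.Tensor:
    """[batch_size, seq, n_heads * d_head] -> [batch_size * n_heads, seq, d_head]."""
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, n_heads, dim // n_heads)
    # Fold the head axis into the batch axis so bmm treats each head separately
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * n_heads, seq_len, dim // n_heads)
```

Its counterpart `reshape_batch_dim_to_heads` undoes the fold after `_attention`, which is why both appear around the `baddbmm`/`bmm` pair above.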