Why F.scaled_dot_product_attention output in this case differs from normal attention

I have created a minimal script to reproduce the issue

I am trying to re-implement the normal attention function without much of a math background. I can't get F.scaled_dot_product_attention to produce nearly the same output as normal_attention, and after a whole day of debugging I need some help.

#!/usr/bin/env python3

import torch

def normal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """
    #### Normal Attention

    :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    """

    # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
    q = q.view(*q.shape[:2], 8, -1)
    k = k.view(*k.shape[:2], 8, -1)
    v = v.view(*v.shape[:2], 8, -1)

    # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
    attn = torch.einsum('bihd,bjhd->bhij', q, k)

    # Compute softmax
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
    half = attn.shape[0] // 2
    attn[half:] = attn[half:].softmax(dim=-1)
    attn[:half] = attn[:half].softmax(dim=-1)

    # Compute attention output
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
    out = torch.einsum('bhij,bjhd->bihd', attn, v)
    # Reshape back to `[batch_size, seq_len, n_heads * d_head]`
    out = out.reshape(*out.shape[:2], -1)
    # (The original maps this to `[batch_size, seq_len, d_model]` with a linear
    # output projection; that layer is omitted here.)
    return out

import torch.nn.functional as F

query = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
key = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
value = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")

result_normal = normal_attention(query, key, value)

query = query.view(32, 128, 8, 64).transpose(1, 2)
key = key.view(32, 128, 8, 64).transpose(1, 2)
value = value.view(32, 128, 8, 64).transpose(1, 2)

with torch.backends.cuda.sdp_kernel(enable_math=False):
    result_torch = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

result_torch = result_torch.transpose(1, 2).reshape(32, 128, 512)

print(result_torch.shape)
print(result_normal.shape)
print((result_normal - result_torch).abs().max())

result_torch and result_normal are not even close to each other.

The normal_attention function above is from the diffusers 0.8.0 repository on GitHub, with the variables replaced by constant numbers. My usage of F.scaled_dot_product_attention is my best guess, without understanding the math.

Thanks in advance

Additional explanation for the normal_attention hack: it is a trick to use less VRAM so the model fits on small GPUs.
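To make that hack concrete: as far as I can tell, it just runs the softmax on half of the batch at a time and overwrites the scores in place, which should give the same values as one plain softmax, because softmax along dim=-1 is computed independently for every row. A minimal sketch of that check (the sizes here are made up):

    import torch

    attn = torch.rand(4, 8, 16, 16)          # [batch, heads, seq, seq], made-up sizes
    reference = attn.softmax(dim=-1)          # plain softmax over the key dimension

    hacked = attn.clone()
    half = hacked.shape[0] // 2
    hacked[half:] = hacked[half:].softmax(dim=-1)   # second half of the batch, in place
    hacked[:half] = hacked[:half].softmax(dim=-1)   # first half of the batch, in place

    print((reference - hacked).abs().max())   # I expect 0, or at most float rounding noise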

normal_attention is at most ~0.06 away from the standard attention implementation. The xformers library has a scaled_dot_product_attention as well; it produces output within ~0.0005 of the original attention (near exact), and is also ~0.06 away from this normal_attention hack.

However, when the xformers scaled_dot_product_attention is given 4-dim inputs, its output is near identical to F.scaled_dot_product_attention.

So the problem is how to let F.scaled_dot_product_attention take 3-dim inputs and produce reasonably near exact output.
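For reference, here is how I would try the 3-dim route. This is only a sketch under my assumption that F.scaled_dot_product_attention treats the last two dims as [seq, d_head] and broadcasts over whatever comes before, so a [batch * n_heads, seq, d_head] tensor should be accepted; the to_batch_heads helper is just something I made up for the example:

    import torch
    import torch.nn.functional as F

    b, s, h, d = 32, 128, 8, 64
    q = torch.rand(b, s, h * d, dtype=torch.float16, device="cuda")
    k = torch.rand(b, s, h * d, dtype=torch.float16, device="cuda")
    v = torch.rand(b, s, h * d, dtype=torch.float16, device="cuda")

    def to_batch_heads(x):
        # [b, s, h*d] -> [b*h, s, d]
        return x.view(b, s, h, d).transpose(1, 2).reshape(b * h, s, d)

    # 3-dim call: one "batch" entry per (batch, head) pair
    out3 = F.scaled_dot_product_attention(to_batch_heads(q), to_batch_heads(k), to_batch_heads(v))
    out3 = out3.reshape(b, h, s, d).transpose(1, 2).reshape(b, s, h * d)

    # 4-dim call, as in the script above
    out4 = F.scaled_dot_product_attention(
        q.view(b, s, h, d).transpose(1, 2),
        k.view(b, s, h, d).transpose(1, 2),
        v.view(b, s, h, d).transpose(1, 2),
    ).transpose(1, 2).reshape(b, s, h * d)

    print((out3 - out4).abs().max())  # I expect this to be ~0 up to fp16 noise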

I'll put a standard attention implementation without the hack here, in case the ~0.06 error seems annoying:

    def _attention(self, query, key, value):
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)
        # compute attention output
        hidden_states = torch.bmm(attention_probs, value)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

The self.scale above is 64 ** -0.5 in this particular case. This standard implementation is within a max error of ~0.06 of the normal_attention implementation.
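For context, _attention also relies on a couple of reshape helpers from the diffusers CrossAttention class. The sketch below is my paraphrase of them as standalone functions (written from memory, so treat the exact code as an approximation), assuming 8 heads as in the rest of this post:

    import torch

    def reshape_heads_to_batch_dim(tensor: torch.Tensor, n_heads: int = 8) -> torch.Tensor:
        # [batch, seq, n_heads * d_head] -> [batch * n_heads, seq, d_head]
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, n_heads, dim // n_heads)
        return tensor.permute(0, 2, 1, 3).reshape(batch_size * n_heads, seq_len, dim // n_heads)

    def reshape_batch_dim_to_heads(tensor: torch.Tensor, n_heads: int = 8) -> torch.Tensor:
        # [batch * n_heads, seq, d_head] -> [batch, seq, n_heads * d_head]
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // n_heads, n_heads, seq_len, dim)
        return tensor.permute(0, 2, 1, 3).reshape(batch_size // n_heads, seq_len, dim * n_heads)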

I finally figured it out. Surprisingly, my F.scaled_dot_product_attention usage is correct; it's the normal attention implementation that can't be used this way. The inputs need to go through reshape_heads_to_batch_dim, as follows:

    q2 = reshape_heads_to_batch_dim(query)

I don't know why, or what exactly reshape_heads_to_batch_dim needs to be added for in normal_attention, but I'm glad I finally made it work.
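For completeness, here is my end-to-end sketch of the comparison that finally matched. It is standalone, with the reshapes inlined (so the shapes are hard-coded to this example) and scale = 64 ** -0.5; as far as I understand, F.scaled_dot_product_attention applies the same 1/sqrt(d_head) scaling by default.

    import torch
    import torch.nn.functional as F

    query = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
    key = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
    value = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")

    # Standard attention on [batch * n_heads, seq, d_head] inputs; the
    # view/transpose/reshape chain plays the role of reshape_heads_to_batch_dim.
    q2 = query.view(32, 128, 8, 64).transpose(1, 2).reshape(256, 128, 64)
    k2 = key.view(32, 128, 8, 64).transpose(1, 2).reshape(256, 128, 64)
    v2 = value.view(32, 128, 8, 64).transpose(1, 2).reshape(256, 128, 64)
    scores = torch.baddbmm(
        torch.empty(256, 128, 128, dtype=q2.dtype, device=q2.device),
        q2, k2.transpose(-1, -2), beta=0, alpha=64 ** -0.5,
    )
    result_standard = torch.bmm(scores.softmax(dim=-1), v2)
    # reshape_batch_dim_to_heads: [256, 128, 64] -> [32, 128, 512]
    result_standard = result_standard.view(32, 8, 128, 64).transpose(1, 2).reshape(32, 128, 512)

    # F.scaled_dot_product_attention on [batch, n_heads, seq, d_head] inputs
    q4 = query.view(32, 128, 8, 64).transpose(1, 2)
    k4 = key.view(32, 128, 8, 64).transpose(1, 2)
    v4 = value.view(32, 128, 8, 64).transpose(1, 2)
    result_torch = F.scaled_dot_product_attention(q4, k4, v4).transpose(1, 2).reshape(32, 128, 512)

    print((result_standard - result_torch).abs().max())  # small fp16-level difference expected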