I have created a minimal script to reproduce the issue.

I am trying to re-implement the normal attention function without a math background, but I can't get `F.scaled_dot_product_attention` to produce nearly the same output as `normal_attention`. After a whole day of debugging, I need some help.

```
#!/usr/bin/env python3
import torch
def normal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """
    #### Normal Attention
    :param q: query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    :param k: key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    :param v: value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
    """
    # Split them into heads of shape `[batch_size, seq_len, n_heads, d_head]`
    q = q.view(*q.shape[:2], 8, -1)
    k = k.view(*k.shape[:2], 8, -1)
    v = v.view(*v.shape[:2], 8, -1)
    # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
    attn = torch.einsum('bihd,bjhd->bhij', q, k)
    # Compute softmax
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
    half = attn.shape[0] // 2
    attn[half:] = attn[half:].softmax(dim=-1)
    attn[:half] = attn[:half].softmax(dim=-1)
    # Compute attention output
    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
    out = torch.einsum('bhij,bjhd->bihd', attn, v)
    # Concatenate heads back to `[batch_size, seq_len, n_heads * d_head]`
    out = out.reshape(*out.shape[:2], -1)
    return out
import torch.nn.functional as F
query = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
key = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
value = torch.rand(32, 128, 512, dtype=torch.float16, device="cuda")
result_normal = normal_attention(query, key, value)
query = query.view(32, 128, 8, 64).transpose(1, 2)
key = key.view(32, 128, 8, 64).transpose(1, 2)
value = value.view(32, 128, 8, 64).transpose(1, 2)
with torch.backends.cuda.sdp_kernel(enable_math=False):
    result_torch = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
result_torch = result_torch.transpose(1, 2).reshape(32, 128, 512)
print(result_torch.shape)
print(result_normal.shape)
print((result_normal - result_torch).abs().max())
```

`result_torch` and `result_normal` are not even close to each other.
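A sketch of how the mismatch can be quantified with dtype-aware tolerances rather than a raw max-abs difference; the helper name `roughly_equal` and the fp16 tolerance values here are my own illustrative guesses:

```python
import torch

def roughly_equal(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Tolerances chosen loosely for float16; illustrative guesses only
    try:
        torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)
        return True
    except AssertionError:
        return False

a = torch.rand(32, 128, 512, dtype=torch.float16)
print(roughly_equal(a, a.clone()))  # True
print(roughly_equal(a, a + 0.5))    # False
```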

The `normal_attention` function above is taken from the diffusers 0.8.0 repository on GitHub, with constants substituted for the original variables. My call to `F.scaled_dot_product_attention` is my best guess, since I don't understand the math.
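For reference, my reading of the PyTorch documentation is that, with no mask and no dropout, `F.scaled_dot_product_attention` computes `softmax(Q K^T / sqrt(d_head)) V` over the last two dimensions of `[batch, n_heads, seq, d_head]` inputs. A small CPU sketch of that formula (the name `sdpa_reference` is mine, not from the library):

```python
import math
import torch

def sdpa_reference(q, k, v):
    # q, k, v: [batch, n_heads, seq, d_head]
    # softmax(Q K^T / sqrt(d_head)) V, no mask, no dropout
    scale = 1.0 / math.sqrt(q.shape[-1])
    attn = (q @ k.transpose(-2, -1)) * scale
    attn = attn.softmax(dim=-1)
    return attn @ v

q = torch.rand(2, 8, 16, 64)
k = torch.rand(2, 8, 16, 64)
v = torch.rand(2, 8, 16, 64)
out = sdpa_reference(q, k, v)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```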

Thanks in advance