scaled_dot_product_attention - correct input shape

The documentation of scaled_dot_product_attention suggests the following dimensions for inputs:

  • query: (N,...,L,E)
  • key: (N,...,S,E)
  • value: (N,...,S,Ev)

So these are three-dimensional. However, when running either the flash or the memory-efficient implementation, I get exceptions along the lines of "Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: 3, Key dim: 3, Value dim: 3 instead."

Question: What are the four dimensions? Testing strongly suggests that it is (batch, heads, sequence, embedding), which also agrees with various code examples.

However, the closest thing to actual documentation I can find is the documentation for xFormers, which suggests (batch, sequence, heads, embedding) (xFormers optimized operators | xFormers 0.0.24 documentation).

The shape notation contains an ellipsis (...), which indicates that you can add extra dimensions in between; since the function operates on the last two dimensions, those extra dimensions don't change the computation.

That is correct: the general order is (batch, heads, sequence, embedding), where the number of heads can be any value. And even if, for some reason, you have additional dimensions, it still doesn't matter; the function will compute the attention scores over the L, S, and E dimensions. Example:

>>> import torch
>>> import torch.nn.functional as F
>>> q = torch.randn(3, 4, 25, 64)  # b, h, L, E
>>> k = torch.randn(3, 4, 28, 64)  # b, h, S, E
>>> v = torch.randn(3, 4, 28, 64)  # b, h, S, Ev
>>> x1 = F.scaled_dot_product_attention(q, k, v)
>>> x1.shape
torch.Size([3, 4, 25, 64])

But even if you add more dimensions before the last two, the computation does not change:

>>> x2 = F.scaled_dot_product_attention(q.view(3, 2, 2, 25, 64), k.view(3, 2, 2, 28, 64), v.view(3, 2, 2, 28, 64))
>>> x2.shape
torch.Size([3, 2, 2, 25, 64])

>>> (x1.flatten() == x2.flatten()).all()
tensor(True)
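
As for the xFormers layout you mention, (batch, sequence, heads, embedding): if your tensors are in that layout, transposing the sequence and heads dimensions is enough before calling F.scaled_dot_product_attention. A small sketch (the sizes are made up):

>>> q_xf = torch.randn(3, 25, 4, 64)  # b, L, h, E (xFormers-style layout)
>>> k_xf = torch.randn(3, 28, 4, 64)  # b, S, h, E
>>> v_xf = torch.randn(3, 28, 4, 64)  # b, S, h, Ev
>>> out = F.scaled_dot_product_attention(q_xf.transpose(1, 2), k_xf.transpose(1, 2), v_xf.transpose(1, 2))
>>> out.shape
torch.Size([3, 4, 25, 64])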

Hope this helps!

Thanks. I think we have a slight misunderstanding. My point was that, at least for the torch version I have ('2.1.0+cu121'), some of the kernels behind scaled_dot_product_attention mandate 4-dimensional inputs.

Sample code: when selecting memory-efficient attention (on a V100 GPU), using anything other than four-dimensional inputs leads to a runtime error.

import torch
import torch.nn.functional as F

assert torch.cuda.is_available()
device = torch.cuda.current_device()

torch.manual_seed(42)

mu, sigma = 100., 1.0

# Four-dim. tensors -- works fine
q = torch.normal(mu, sigma, size=(4, 6, 8, 2), device=device)
k = torch.normal(mu, sigma, size=(4, 6, 8, 2), device=device)
v = torch.normal(mu, sigma, size=(4, 6, 8, 2), device=device)

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    attn = F.scaled_dot_product_attention(q, k, v)
print('Success for 4D')

# Three-dim. tensors -- yields runtime error
q = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
k = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
v = torch.normal(mu, sigma, size=(4, 6, 8), device=device)

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    attn = F.scaled_dot_product_attention(q, k, v)

Error message:

/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
  attn = F.scaled_dot_product_attention(q, k, v)
/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: 3, Key dim: 3, Value dim: 3 instead. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:273.)
  attn = F.scaled_dot_product_attention(q, k, v)
/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
  attn = F.scaled_dot_product_attention(q, k, v)
/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Flash attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:425.)
  attn = F.scaled_dot_product_attention(q, k, v)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 28
     25 v = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
     27 with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
---> 28     attn = F.scaled_dot_product_attention(q, k, v)

RuntimeError: No available kernel.  Aborting execution.
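
For completeness: the same three-dimensional tensors go through once the math fallback is allowed, so it really is only the fused kernels that insist on four dimensions. A quick sketch continuing the snippet above:

# Same three-dim. tensors as above -- works when the math kernel is enabled
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    attn = F.scaled_dot_product_attention(q, k, v)
print('Success for 3D with the math kernel')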

Hi @jws ,

Did you ever find a solution to this problem? I'm encountering the same issue, although I have 5-dimensional query, key, and value tensors (instead of 4).
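
For now I'm working around it by folding the extra leading dimension into the batch dimension so that the fused kernels see 4-dimensional inputs, and restoring the shape afterwards; since the attention is computed independently across the leading dimensions, the result should be the same. A rough sketch with made-up sizes:

import torch
import torch.nn.functional as F

device = torch.cuda.current_device()

# hypothetical 5-dim. inputs: (batch, extra, heads, seq_len, head_dim)
q = torch.randn(2, 3, 4, 8, 16, device=device)
k = torch.randn(2, 3, 4, 8, 16, device=device)
v = torch.randn(2, 3, 4, 8, 16, device=device)

b, x, h, s, e = q.shape

# fold the extra leading dim into the batch dim -> 4-dim. tensors for the fused kernels
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    attn = F.scaled_dot_product_attention(
        q.reshape(b * x, h, s, e),
        k.reshape(b * x, h, s, e),
        v.reshape(b * x, h, s, e),
    )

# restore the original leading dimensions
attn = attn.reshape(b, x, h, s, e)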

Thanks!