Thanks. I think we have a slight misunderstanding. My point was that, at least for the torch version I have ('2.1.0+cu121'), some of the kernels behind scaled_dot_product_attention mandate 4-dimensional inputs.
Sample code: when selecting the memory-efficient attention backend (on a V100 GPU), using anything other than 4-dimensional inputs leads to a runtime error.
import torch
import torch.nn.functional as F

assert torch.cuda.is_available()
device = torch.cuda.current_device()
torch.manual_seed(42)

mu, sigma = 100., 1.0

# 4-dimensional tensors (batch, heads, seq_len, head_dim) -- works fine
q = torch.normal(mu, sigma, size=(4, 6, 8, 2), device=device)
k = torch.normal(mu, sigma, size=(4, 6, 8, 2), device=device)
v = torch.normal(mu, sigma, size=(4, 6, 8, 2), device=device)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    attn = F.scaled_dot_product_attention(q, k, v)
print('Success for 4D')

# 3-dimensional tensors (batch, seq_len, embed_dim) -- raises a RuntimeError
q = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
k = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
v = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    attn = F.scaled_dot_product_attention(q, k, v)
Error message:
/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
attn = F.scaled_dot_product_attention(q, k, v)
/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Both fused kernels requires query, key and value to be 4 dimensional, but got Query dim: 3, Key dim: 3, Value dim: 3 instead. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:273.)
attn = F.scaled_dot_product_attention(q, k, v)
/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
attn = F.scaled_dot_product_attention(q, k, v)
/tmp/ipykernel_746318/3205699521.py:28: UserWarning: Flash attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:425.)
attn = F.scaled_dot_product_attention(q, k, v)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 28
25 v = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
27 with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
---> 28 attn = F.scaled_dot_product_attention(q, k, v)
RuntimeError: No available kernel. Aborting execution.
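For completeness, a minimal sketch of a possible workaround (my own addition, not something the docs prescribe): if the inputs are 3-dimensional, insert a singleton head dimension with unsqueeze before calling scaled_dot_product_attention and drop it again afterwards, so the fused kernel sees the 4D layout it expects.
import torch
import torch.nn.functional as F

assert torch.cuda.is_available()
device = torch.cuda.current_device()
torch.manual_seed(42)

mu, sigma = 100., 1.0
# 3-dimensional inputs: (batch, seq_len, embed_dim)
q = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
k = torch.normal(mu, sigma, size=(4, 6, 8), device=device)
v = torch.normal(mu, sigma, size=(4, 6, 8), device=device)

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    # unsqueeze to (batch, 1, seq_len, embed_dim) so the fused kernel accepts the input
    attn = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))
attn = attn.squeeze(1)  # back to (batch, seq_len, embed_dim)
print(attn.shape)  # torch.Size([4, 6, 8])
Whether treating the whole embedding as a single head makes sense depends on the model, of course; the point is only that the fused kernels require the explicit head dimension.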