Troubles with scaled_dot_product_attention

Hello, I am trying to implement my own neural machine translation model with Flash Attention (using scaled_dot_product_attention from torch.nn.functional).
I have two problems with it:

  1. When I use dtype=torch.float16, I get the following error:
RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'
  2. When I use device = 'cuda', I get this error:
RuntimeError: No available kernel.  Aborting execution.

For the second problem: I found a suggested solution (upgrading PyTorch to the version mentioned in that post), but it didn't help.

Here is the piece of code that raises the error:

if self.flash:
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
        y = F.scaled_dot_product_attention(Q, K, V,
                                           attn_mask=mask,
                                           dropout_p=self.dropout if self.training else 0,
                                           is_causal=True)
else:
    raise ImportError("PyTorch >= 2.0 must be installed for using Flash Attention")

I can post the full function code if necessary.

Which PyTorch version and which GPU are you using?

I am using a T4 GPU (from Google Colab).
PyTorch version: 2.1.0.dev20230618+cu121

Thank you! Could you post the missing code pieces to create a minimal and executable code snippet, please?

@ptrblck hi!

Attention module code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.dropout = config.dropout

        # support is only in PyTorch >= 2.0
        self.flash = hasattr(F, 'scaled_dot_product_attention')

        self.out_linear = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(self, K, V, Q, mask=None):

        batch_size, hidden_dim = Q.size(0), Q.size(2)
        key_len, value_len, query_len = K.size(1), V.size(1), Q.size(1)

        assert hidden_dim % self.num_heads == 0, "Hidden_dim must be equal to num_heads * head_dim"

        K = K.reshape(batch_size, key_len, self.num_heads, -1).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
        V = V.reshape(batch_size, value_len, self.num_heads, -1).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
        Q = Q.reshape(batch_size, query_len, self.num_heads, -1).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)

        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
                y = F.scaled_dot_product_attention(Q, K, V,
                                                  attn_mask=mask,
                                                  dropout_p=self.dropout if self.training else 0,
                                                  is_causal=True)
        else:
            raise ImportError("PyTorch >= 2.0 must be installed for using Flash Attention")

        y = y.transpose(1, 2).contiguous().view(batch_size, query_len, hidden_dim)
        # y = self.resid_dropout(self.out_linear(y)) + Q
        return self.resid_dropout(self.out_linear(y))
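
For reference, here is a minimal sketch of what the is_causal=True call in the module computes, written with plain tensor operations and compared against the fused function. The shapes, the seed, and the variable names are illustrative assumptions, and it runs on CPU in float32 so it does not depend on any particular backend being available.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative shapes: (batch_size, num_heads, seq_len, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Fused call, as in the module (no explicit mask, no dropout)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Plain reference: scaled scores, lower-triangular mask, softmax, weighted sum
seq_len = q.size(-2)
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
reference = torch.softmax(scores, dim=-1) @ v

max_diff = (fused - reference).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

Note that the PyTorch documentation states an error is raised when attn_mask and is_causal=True are both set, so the forward above is only expected to work with mask=None.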

Testing code:

dtype = torch.float16
device = 'cuda'

query = torch.randn(32, 256, 512, device=device)
key = torch.randn(32, 256, 512, device=device)
value = torch.randn(32, 256, 512, device=device)

att = CausalSelfAttention(cfg)
att(key, value, query).shape

When I pass device=device to the key, value and query tensors, I get RuntimeError: No available kernel.
When I pass dtype=dtype to the key, value and query tensors, I get RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'

Your code is not executable as cfg is undefined.

My mistake.
Here is the config:

from typing import Tuple

class config:
    batch_size: int = 16
    wmt14_size: int = 1_000_000
    # model parameters
    hidden_dim: int = 512
    num_heads: int = 8
    dropout: float = 0.1
    # scheduler parameters
    warmup_steps: int = 4000
    #optimizer parameters
    betas: Tuple[float, float] = (0.9, 0.98)
    eps: float = 1.0e-9

cfg = config()

Thank you!
It seems you are running into a known limitation, as you are selecting the SDPA backend manually. E.g. I see this warning:

UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: float, Key dtype: float, and Value dtype: float instead. 

Using an autocast context manager works for me as it’s casting to the expected dtypes:

att = CausalSelfAttention(cfg).to(device)
with torch.cuda.amp.autocast():
    att(key, value, query).shape

From my understanding, the SDPA backend is selected automatically for you based on conditions such as the dtype, while forcing a specific backend manually can easily fail when its requirements are not met.
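
A minimal sketch of relying on that automatic selection, assuming PyTorch 2.x: the call is made without any sdp_kernel context, so PyTorch is free to pick a backend that supports the given inputs. The tensor shapes and the CPU/float32 fallback are illustrative assumptions, not part of the original code.

```python
import torch
import torch.nn.functional as F

# Use half precision on the GPU; fall back to float32 on CPU (assumption for portability)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch_size, num_heads, seq_len, head_dim), the layout the module produces
q = torch.randn(2, 8, 16, 64, device=device, dtype=dtype)
k = torch.randn(2, 8, 16, 64, device=device, dtype=dtype)
v = torch.randn(2, 8, 16, 64, device=device, dtype=dtype)

# No backend is forced here, so an unsupported dtype or device
# does not end in "No available kernel"
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape, out.dtype)
```

Which backend actually runs depends on the device, dtype and PyTorch build, so this sketch makes no claim that Flash Attention specifically is used.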

Thank you so much for your explanation!
But I have one more question:
With torch.cuda.amp.autocast(), are both the SDPA backend and the tensor dtype selected automatically, or only the backend, so that we still need to set dtype=bfloat16 on the tensors ourselves?

By default autocast will use float16 and works properly on my device. You can specify the mixed-precision dtype to be bfloat16 if you want.
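
As a rough illustration of specifying the mixed-precision dtype, the sketch below uses the generic torch.autocast context manager with dtype=torch.bfloat16 on a single linear layer. The layer size, the input shape, and the fallback to CPU when the GPU does not report bfloat16 support are assumptions made for the example.

```python
import torch
import torch.nn as nn

# Prefer the GPU only if it reports bfloat16 support; otherwise use CPU autocast
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    device = "cuda"
else:
    device = "cpu"

layer = nn.Linear(512, 512).to(device)
x = torch.randn(4, 512, device=device)  # input stays float32

# Inside the context, eligible ops such as linear layers run in bfloat16;
# the tensors themselves do not need to be created with that dtype
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y = layer(x)

print(x.dtype, y.dtype)
```

The input and the layer weights remain float32; only the computation inside the context is cast.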

Thank you so much!
With your help this solution works perfectly