Troubles with scaled_dot_product_attention

Hello, I'm trying to implement my own neural machine translation model with Flash Attention (using scaled_dot_product_attention from torch.nn.functional).
I have two troubles with it:

  1. When I want to use dtype=torch.float16, I get the following error:
RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'
  2. When I try to use device = 'cuda', I get this error:
RuntimeError: No available kernel.  Aborting execution.

For point 2 I found a suggested solution (upgrading PyTorch to the version mentioned in that post), but it didn't help me.

Here is the piece of code that raises the error:

if self.flash:
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
        y = F.scaled_dot_product_attention(Q, K, V,
                                           attn_mask=mask,
                                           dropout_p=self.dropout if self.training else 0,
                                           is_causal=True)
else:
    raise ImportError("PyTorch >= 2.0 must be installed for using Flash Attention")

I can post the full function code if necessary.


Which PyTorch version and which GPU are you using?

I use a T4 GPU (from Google Colab).
PyTorch version: 2.1.0.dev20230618+cu121

Thank you! Could you post the missing code pieces to create a minimal and executable code snippet, please?

@ptrblck hi!

Attention module code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.dropout = config.dropout

        # support is only in PyTorch >= 2.0
        self.flash = hasattr(F, 'scaled_dot_product_attention')

        self.out_linear = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(self, K, V, Q, mask=None):

        batch_size, hidden_dim = Q.size(0), Q.size(2)
        key_len, value_len, query_len = K.size(1), V.size(1), Q.size(1)

        assert hidden_dim % self.num_heads == 0, "Hidden_dim must be equal to num_heads * head_dim"

        K = K.reshape(batch_size, key_len, self.num_heads, -1).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
        V = V.reshape(batch_size, value_len, self.num_heads, -1).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)
        Q = Q.reshape(batch_size, query_len, self.num_heads, -1).transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim)

        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
                y = F.scaled_dot_product_attention(Q, K, V,
                                                  attn_mask=mask,
                                                  dropout_p=self.dropout if self.training else 0,
                                                  is_causal=True)
        else:
            raise ImportError("PyTorch >= 2.0 must be installed for using Flash Attention")

        y = y.transpose(1, 2).contiguous().view(batch_size, query_len, hidden_dim)
        # y = self.resid_dropout(self.out_linear(y)) + Q
        return self.resid_dropout(self.out_linear(y))

Testing code:

dtype = torch.float16
device = 'cuda'

query = torch.randn(32, 256, 512, device=device)
key = torch.randn(32, 256, 512, device=device)
value = torch.randn(32, 256, 512, device=device)

att = CausalSelfAttention(cfg)
att(key, value, query).shape

When I set device=device for the k, v, q tensors, it raises RuntimeError: "No available kernel."
When I set dtype=dtype for the k, v, q tensors, it raises RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'.

Your code is not executable as cfg is undefined.

My mistake.
Config:

from typing import Tuple

class config:
    batch_size: int = 16 
    wmt14_size: int = 1E6 
    # model parameters
    hidden_dim: int = 512
    num_heads: int = 8
    dropout: float = 0.1
    # scheduler parameters
    warmup_steps: int = 4000
    #optimizer parameters
    betas: Tuple[float, float] = (0.9, 0.98)
    eps: float = 1.0e-9

cfg = config()

Thank you!
It seems you are running into a known limitation as you are using the SDPA backend manually. E.g. I see:

UserWarning: Expected query, key and value to all be of dtype: {Half, BFloat16}. Got Query dtype: float, Key dtype: float, and Value dtype: float instead. 

Using an autocast context manager works for me as it’s casting to the expected dtypes:

att = CausalSelfAttention(cfg).to(device)
with torch.cuda.amp.autocast():
    att(key, value, query).shape

From my understanding the SDPA backends would be selected automatically for you based on some conditions (such as dtype) while the manual implementation could easily fail.
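
For example, a minimal sketch (using the same shapes as your module, i.e. 8 heads of dim 64) that lets PyTorch pick the backend automatically under autocast instead of forcing the flash kernel:

import torch
import torch.nn.functional as F

device = 'cuda'
# fp32 inputs; autocast will cast them to float16 inside the region
query = torch.randn(32, 8, 256, 64, device=device)
key = torch.randn(32, 8, 256, 64, device=device)
value = torch.randn(32, 8, 256, 64, device=device)

with torch.cuda.amp.autocast():
    # no sdp_kernel context manager: PyTorch selects a suitable backend
    # (flash / mem-efficient / math) based on dtype, device, shapes, etc.
    y = F.scaled_dot_product_attention(query, key, value, is_causal=True)

print(y.shape, y.dtype)  # expected: torch.Size([32, 8, 256, 64]) torch.float16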


Thank you so much for your explanation!
But I have one more question:
With torch.cuda.amp.autocast(), will both the SDPA backend and the tensor dtypes be selected automatically, or only the former, meaning we still need to create the tensors with dtype = bfloat16?

By default autocast will use float16, which works properly on my device. You can specify the mixed-precision dtype to be bfloat16 if you want.
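
For example, something like this should work (a sketch reusing att, key, value, query from the snippet above):

# note that bfloat16 kernels need suitable GPU support (e.g. Ampere or newer for flash attention)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    out = att(key, value, query)
print(out.dtype)  # expected: torch.bfloat16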


Thank you so much!
With your help this solution works perfectly.
