Trouble when using Scaled Dot Product Attention

i am trying to modify Roberta Self Attention class from HuggingFace’s Transformers Lib with Flash Attention, however i received an error from Scaled Dot Product Attention that make me can’t do it. I am working with A100 and torch 2.2.1
This is the error

No available kernel.  Aborting execution.

Here is my implementation

import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from transformers.models.roberta.modeling_roberta import RobertaSelfAttention, RobertaModel
from transformers import  AutoConfig, AutoTokenizer
from typing import Tuple, Optional
import copy

class RobertaFlashAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        
        mixed_query_layer = self.query(hidden_states)
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # if attention_mask is not None:
        #     # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
        #     attention_scores = attention_scores + attention_mask

        # # Normalize the attention scores to probabilities.
        # attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # # This is actually dropping out entire tokens to attend to, which might
        # # seem a bit unusual, but is taken from the original Transformer paper.
        # attention_probs = self.dropout(attention_probs)

        # # Mask heads if we want to
        # if head_mask is not None:
        #     attention_probs = attention_probs * head_mask

        # context_layer = torch.matmul(attention_probs, value_layer)

        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
            F.scaled_dot_product_attention(query_layer,key_layer,value_layer, attn_mask=attention_mask)

        print(context_layer.size())

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

config = AutoConfig.from_pretrained("vinai/phobert-base-v2")
roberta_self_attention = RobertaSelfAttention(config=config)
 
for i, layer in enumerate(model.encoder.layer):
    flash_self_attn = RobertaFlashAttention(config)
    flash_self_attn.query = layer.attention.self.query
    flash_self_attn.key = layer.attention.self.key
    flash_self_attn.value = layer.attention.self.value

    flash_self_attn.query_global = copy.deepcopy(layer.attention.self.query)
    flash_self_attn.key_global = copy.deepcopy(layer.attention.self.key)
    flash_self_attn.value_global = copy.deepcopy(layer.attention.self.value)

    layer.attention.self = flash_self_attn
model = model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
inputs = tokenizer(text=["con_chó vừa béo vừa kiêu"], return_tensors="pt")
for key in inputs.keys():
    inputs[key] = inputs[key].to("cuda")
model(**inputs)