Does torch.compile use FlashAttention?

I tested the performance of torch.compile on a bert-base model on an A100 machine and found that training performance improved significantly.
I wonder whether FlashAttention is used under torch.compile. Is there an option to make torch.compile disable FlashAttention?

Take a look at this tutorial: (Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA) — PyTorch Tutorials 2.0.0+cu117 documentation

Specifically, take a look at the backend_map dict there: pass its entries into the sdp_kernel context manager and then torch.compile your model within that scope to disable specific backends such as FlashAttention.
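Something like this (an untested sketch with toy tensor shapes; the backend_map values are the ones from that tutorial):

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel, SDPBackend

# kwargs for sdp_kernel that enable exactly one SDPA backend each (from the tutorial)
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}

# toy q/k/v in (batch, heads, seq_len, head_dim), half precision so flash would be eligible
q = torch.randn(8, 12, 128, 64, device="cuda", dtype=torch.half)
k = torch.randn_like(q)
v = torch.randn_like(q)

compiled_sdpa = torch.compile(F.scaled_dot_product_attention)

# restrict SDPA to the plain math backend, i.e. FlashAttention is disabled inside this scope
with sdp_kernel(**backend_map[SDPBackend.MATH]):
    out = compiled_sdpa(q, k, v)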

But to answer your first question: yes, I believe FlashAttention is enabled by default now, with or without torch.compile.

Thank you for your reply.
I didn’t use torch.nn.functional.scaled_dot_product_attention or torch.nn.MultiheadAttention; I used the self-attention implementation from transformers, which looks like this:

# project hidden states to per-head query/key/value: (batch, heads, seq_len, head_dim)
mixed_query_layer = self.query(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

# raw attention scores, scaled by sqrt(head_dim), then softmax over the key dimension
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# weighted sum of values, then merge the heads back into (batch, seq_len, all_head_size)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

I wonder whether torch.compile will use FlashAttention in this case. Also, is it possible to find the generated Triton code locally?

cc @drisspg who might know how to generate the triton kernels - I know how to do this for inductor only
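For inductor, one way that should work is to set the TORCH_COMPILE_DEBUG environment variable before running; roughly like this (the attention_like function and shapes below are just made-up stand-ins for your attention block):

import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # set before torch is imported / before the first compile

import math
import torch

def attention_like(q, k, v):
    # same matmul -> scale -> softmax -> matmul pattern as in your snippet
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    return torch.matmul(scores.softmax(dim=-1), v)

compiled = torch.compile(attention_like)
q = k = v = torch.randn(8, 12, 128, 64, device="cuda")
compiled(q, k, v)

# the generated code should then be dumped under ./torch_compile_debug/,
# including an output_code.py containing the Triton kernels inductor produced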

@dancedpipi, torch.compile() does not generate flash or memory-efficient kernels at this point. They are pre-written custom kernels that only get called if you are using torch.nn.functional.scaled_dot_product_attention or torch.nn.MultiheadAttention.
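If you want the attention in your snippet to be able to hit those kernels, one option is to replace the manual matmul/softmax/matmul block with torch.nn.functional.scaled_dot_product_attention. A rough sketch, not tested against your exact module (the attention function name and its arguments here are just illustrative):

import torch
import torch.nn.functional as F

def attention(query_layer, key_layer, value_layer, attention_mask=None, dropout_p=0.0):
    # query/key/value are (batch, num_heads, seq_len, head_dim), as produced by transpose_for_scores;
    # SDPA applies the 1/sqrt(head_dim) scaling and the softmax internally and can dispatch
    # to the flash / memory-efficient kernels on CUDA
    return F.scaled_dot_product_attention(
        query_layer, key_layer, value_layer,
        attn_mask=attention_mask, dropout_p=dropout_p,
    )

As far as I know, in 2.0 the flash backend is not selected when an explicit attn_mask is passed, so it may fall back to the memory-efficient or math path in that case.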

Does that mean that if I don’t use torch.nn.functional.scaled_dot_product_attention or torch.nn.MultiheadAttention and only use torch.compile, it won’t use the FlashAttention or memory-efficient kernels?