I tested the performance of torch.compile on the bert-base model on an A100 machine, and found that training performance improved greatly. I wonder whether FlashAttention is used under torch.compile. Is there an option to make torch.compile disable FlashAttention?
Take a look at this tutorial: (Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA) — PyTorch Tutorials 2.0.0+cu117 documentation.
Specifically, take a look at the backend_map dict, pass its entries into an sdp_kernel context manager, and then torch.compile your model within that scope to disable specific backends.
But to answer your first question: yes, I believe it is enabled by default now, with or without torch.compile.
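For reference, a minimal sketch of that pattern, assuming the PyTorch 2.0 `torch.backends.cuda.sdp_kernel` context manager (the flag names below match the 2.0 API and may be deprecated in later releases):

```python
import torch
import torch.nn.functional as F

# Toggle which fused SDPA backends are allowed inside the scope.
# Disabling flash while keeping the others means any model compiled
# and run inside this context will not hit the FlashAttention kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False,         # disable FlashAttention
    enable_math=True,           # keep the plain math fallback
    enable_mem_efficient=True,  # keep the memory-efficient kernel
):
    # (batch, heads, seq_len, head_dim)
    q = k = v = torch.randn(2, 8, 128, 64)
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([2, 8, 128, 64])
```

In the same way, you would call `torch.compile(model)` and run the compiled model inside that `with` block so the backend restriction is in effect during execution.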
Thank you for your reply.
I didn't use torch.nn.functional.scaled_dot_product_attention or torch.nn.MultiheadAttention; I used the self-attention structure from transformers, which looks like this:
```python
# Project hidden states to Q/K/V and reshape to (batch, heads, seq, head_dim)
mixed_query_layer = self.query(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

# Scaled dot-product attention, written out by hand
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
context_layer = torch.matmul(attention_probs, value_layer)

# Merge the heads back into a single hidden dimension
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
```
I wonder whether torch.compile will use flash attention in this case? And is it possible to find the generated Triton code locally?
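For what it's worth, the core of the hand-written block above can be collapsed into a single call to torch.nn.functional.scaled_dot_product_attention, which is what allows the fused kernels to be dispatched. A minimal standalone sketch (the shapes below are made-up stand-ins for BERT's batch/head layout):

```python
import math
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 12, 16, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Hand-written attention, as in the snippet above (no mask, no dropout)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
probs = F.softmax(scores, dim=-1)
manual = torch.matmul(probs, v)

# Equivalent fused call; SDPA's default scale is also 1/sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```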
cc @drisspg who might know how to generate the triton kernels - I know how to do this for inductor only
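One way to inspect what inductor generates is the debug dump, an assumption based on the TORCH_COMPILE_DEBUG mechanism documented for PyTorch 2.0 (the exact output layout may differ between releases; `train.py` is a hypothetical stand-in for your script):

```shell
# Ask inductor to dump its generated code (including any Triton kernels)
# to ./torch_compile_debug/ while the compiled model runs
TORCH_COMPILE_DEBUG=1 python train.py
```

The generated kernels end up in files like `output_code.py` under the `torch_compile_debug` directory created in your working directory.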
@dancedpipi, torch.compile() does not generate flash or memory-efficient kernels at this point. They are custom kernels that get called for now if you are using torch.nn.functional.scaled_dot_product_attention or torch.nn.MultiheadAttention.
Does that mean that if I don't use torch.nn.functional.scaled_dot_product_attention or torch.nn.MultiheadAttention, and only use torch.compile, it won't use the flash attention or memory-efficient kernels?