Flash Attention in Transformer API


I have a very small model that I'm using to benchmark how much faster a model gets with different tricks:

import torch
import pytorch_lightning as pl
import torch.nn as nn
import time

class TransformerModel(pl.LightningModule):
    def __init__(self, dim_model, nheads, dim_feedforward, nlayers, batch_first, dropout):
        super(TransformerModel, self).__init__()
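        # A single encoder layer; nn.TransformerEncoder deep-copies it nlayers times
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=batch_first,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer=self.encoder_layer, num_layers=nlayers, enable_nested_tensor=False)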

    def forward(self, x):
        # with torch.backends.cuda.sdp_kernel(
        #     enable_flash=True, enable_math=True, enable_mem_efficient=True
        # ): 
        x = self.encoder(x)
        return x

# Define a function to benchmark the forward pass time
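def benchmark_forward(model, input_tensor, num_iterations=100):
    model.eval()  # disable dropout; the fused fast path also requires eval mode
    total_time = 0.0
    with torch.no_grad():
        for _ in range(num_iterations):
            start_time = time.perf_counter()  # monotonic, high-resolution timer
            output = model(input_tensor)
            end_time = time.perf_counter()
            total_time += end_time - start_time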

    average_time = total_time / num_iterations
    print(f"Average forward pass time: {average_time} seconds")

dim_model = 512
nheads = 8
dim_feedforward = 2048
nlayers = 6
batch_first = True
dropout = 0.1

# Initialize the model
model = TransformerModel(dim_model, nheads, dim_feedforward, nlayers, batch_first, dropout)

# Create a sample input tensor (adjust the shape and data as needed)
input_tensor = torch.randn(6, 1000, dim_model)  # Batch size of 6, sequence length of 1000 (batch_first=True)

# Benchmark the forward pass
benchmark_forward(model, input_tensor)
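If the model is later moved to a GPU (the code above keeps it on the CPU, and the FlashAttention kernels in PyTorch 2.0/2.1 target CUDA), a wall-clock loop without synchronization mostly measures kernel launches, because CUDA work runs asynchronously. Below is a minimal sketch of a fairer timing loop with warm-up runs and `torch.cuda.synchronize`. The helper name `benchmark_forward_synced`, the warm-up count, and the small layer sizes are my own assumptions, not part of the original setup:

```python
import time

import torch
import torch.nn as nn


def benchmark_forward_synced(model, input_tensor, num_iterations=20, warmup=3):
    """Mean forward-pass time in seconds, with warm-up and CUDA sync."""
    device = input_tensor.device
    model = model.to(device).eval()  # eval() disables dropout

    def sync():
        # CUDA kernels are launched asynchronously; wait before reading the clock
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    with torch.no_grad():
        # Warm-up runs absorb one-time costs such as allocation and kernel selection
        for _ in range(warmup):
            model(input_tensor)
        sync()
        start = time.perf_counter()
        for _ in range(num_iterations):
            model(input_tensor)
        sync()
        elapsed = time.perf_counter() - start
    return elapsed / num_iterations


# Small hypothetical configuration so the demo also runs quickly on CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=128, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)
x = torch.randn(4, 50, 64, device=device)
avg = benchmark_forward_synced(encoder, x)
print(f"Average forward pass time: {avg:.6f} s")
```

Timing the whole loop between two synchronizations, rather than each iteration separately, avoids adding a sync stall to every measurement.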

My goal is to use Flash Attention to make my model faster. However, the PyTorch 2.0 documentation (TransformerEncoderLayer — PyTorch 2.1 documentation) seems to say that Flash Attention is used only during inference, not during training. So my question is: how can I leverage Flash Attention through PyTorch's Transformer API? Is it not possible?
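As far as I can tell from the 2.x source, the documented fast path is not the only route to a fused kernel: in training mode, `nn.MultiheadAttention` (which `TransformerEncoderLayer` calls with `need_weights=False`) still goes through `F.scaled_dot_product_attention`, and that function can dispatch to a FlashAttention kernel on supported GPUs with fp16/bf16 inputs. My understanding is that an explicit arbitrary `attn_mask` rules out the flash backend in these versions, while `is_causal=True` does not. The sketch below calls SDPA directly with autograd enabled and checks it against the explicit formula softmax(QKᵀ/√d + causal mask)·V; the tensor sizes are hypothetical. To confirm that the flash kernel itself is selected rather than a fallback, I believe you can wrap the call in the `torch.backends.cuda.sdp_kernel` context manager (already commented out in the model above) with only `enable_flash=True`, which should raise an error if flash cannot be used.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: (batch, heads, seq_len, head_dim)
B, H, L, D = 2, 8, 16, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
# Flash kernels need half precision on CUDA; use float32 on CPU
dtype = torch.float16 if device == "cuda" else torch.float32

q = torch.randn(B, H, L, D, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(B, H, L, D, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(B, H, L, D, device=device, dtype=dtype, requires_grad=True)

# Fused attention; PyTorch picks the backend (flash, memory-efficient or math)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: softmax(Q K^T / sqrt(D) + causal mask) V, computed explicitly
scores = (q @ k.transpose(-2, -1)) / math.sqrt(D)
causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))
ref = torch.softmax(scores.float(), dim=-1).to(dtype) @ v

# The fused op is differentiable, so it can be used during training
out.sum().backward()
```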

The documentation also highlights the benefit of using enable_nested_tensor=True, but I found a bug (already reported) that prevents me from setting it to True.
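To make concrete what `enable_nested_tensor` is about: my understanding is that, on the eval-mode fast path, the encoder turns a padded batch plus `src_key_padding_mask` into a nested tensor so no work is spent on padding. A small sketch of the nested-tensor representation on its own, independent of the encoder (the two sequences and the feature size are hypothetical):

```python
import torch

# Two hypothetical sequences of different lengths, feature size 4
a = torch.randn(3, 4)
b = torch.randn(5, 4)

# A nested tensor stores the ragged batch without padding
nt = torch.nested.nested_tensor([a, b])

# Converting back to a dense tensor pads the shorter sequence with 0.0
padded = torch.nested.to_padded_tensor(nt, 0.0)

# Equivalent key padding mask (True marks a PAD position), the format
# nn.TransformerEncoder expects for src_key_padding_mask
lengths = torch.tensor([a.shape[0], b.shape[0]])
pad_mask = torch.arange(padded.shape[1])[None, :] >= lengths[:, None]
```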

Finally, the documentation also seems to discourage passing an attention_mask except for the causal mask. So is it common practice not to use an attention mask for the PAD tokens?
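On the PAD question, here is a minimal sketch of what `src_key_padding_mask` does in `nn.TransformerEncoder`. Sizes are hypothetical and dropout is set to 0.0 so runs are comparable. With the mask, outputs at real positions do not depend on what sits in the PAD slots; without it, they do. With a causal mask and right padding, real tokens never attend to later PAD tokens anyway, so there the padding mask mainly affects the PAD positions' own outputs, which are usually excluded from the loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, seq_len = 32, 6
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=64,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)
encoder.eval()

# Hypothetical batch: sequence 0 has 4 real tokens + 2 PAD, sequence 1 is full
x = torch.randn(2, seq_len, d_model)
lengths = torch.tensor([4, 6])

# Boolean masks: True means "may NOT attend to this position"
pad_mask = torch.arange(seq_len)[None, :] >= lengths[:, None]
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Same batch, but with garbage written into the PAD slots of sequence 0
x_noisy = x.clone()
x_noisy[0, 4:] = 100.0

# Bidirectional attention with and without the padding mask
out_masked = encoder(x, src_key_padding_mask=pad_mask)
out_masked_noisy = encoder(x_noisy, src_key_padding_mask=pad_mask)
out_unmasked = encoder(x)
out_unmasked_noisy = encoder(x_noisy)

# Causal mask and padding mask can be passed together
out_causal = encoder(x, mask=causal_mask, src_key_padding_mask=pad_mask)
```

Using boolean tensors for both masks keeps their types consistent; mixing a float causal mask with a boolean padding mask triggers a deprecation warning in recent versions.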

Thank you so much!!