Hello,
I have a very small model that I am using to benchmark how fast the model runs with different tricks:
import torch
import pytorch_lightning as pl
import torch.nn as nn
import time
class TransformerModel(pl.LightningModule):
    def __init__(self, dim_model, nheads, dim_feedforward, nlayers, batch_first, dropout):
        super(TransformerModel, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            batch_first=batch_first,
            dropout=dropout,
            layer_norm_eps=1e-12,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=self.encoder_layer,
            num_layers=nlayers,
            enable_nested_tensor=False,
        )

    def forward(self, x):
        # with torch.backends.cuda.sdp_kernel(
        #     enable_flash=True, enable_math=True, enable_mem_efficient=True
        # ):
        x = self.encoder(x)
        return x
# Define a function to benchmark the forward pass time
def benchmark_forward(model, input_tensor, num_iterations=100):
    model.eval()
    total_time = 0.0
    with torch.no_grad():
        for _ in range(num_iterations):
            start_time = time.time()
            output = model(input_tensor)
            end_time = time.time()
            total_time += end_time - start_time
    average_time = total_time / num_iterations
    print(f"Average forward pass time: {average_time} seconds")
dim_model = 512
nheads = 8
dim_feedforward = 2048
nlayers = 6
batch_first = True
dropout = 0.1
# Initialize the model
model = TransformerModel(dim_model, nheads, dim_feedforward, nlayers, batch_first, dropout)
# Create a sample input tensor (adjust the shape and data as needed)
input_tensor = torch.randn(6, 1000, dim_model)  # Batch size of 6, sequence length of 1000
# Benchmark the forward pass
benchmark_forward(model, input_tensor)
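As a side note on the timing itself: on CPU, wrapping the forward pass with time.time() seems fine, but if I move the model to GPU I believe I need to synchronize before reading the clock. A rough sketch of what I would change (assuming CUDA is available; benchmark_forward_gpu is just my own name for the variant):

def benchmark_forward_gpu(model, input_tensor, num_iterations=100):
    model.eval()
    total_time = 0.0
    with torch.no_grad():
        for _ in range(num_iterations):
            torch.cuda.synchronize()   # make sure previously queued kernels have finished
            start_time = time.time()
            output = model(input_tensor)
            torch.cuda.synchronize()   # wait for this forward pass to actually complete
            total_time += time.time() - start_time
    print(f"Average forward pass time: {total_time / num_iterations} seconds")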
The point is that I want to use Flash Attention to make my model faster. However, from the PyTorch 2.0 documentation (TransformerEncoderLayer — PyTorch 2.1 documentation) it appears that Flash Attention is only used during inference, not at training time. Hence my question: how can I leverage Flash Attention through the Transformer API of PyTorch? Is it not possible?
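For reference, this is roughly what I have been trying (just a sketch; I am not sure the enable_* flags or the half-precision cast are the right way to force the Flash backend):

# Attempt at forcing the Flash backend (sketch; assumes a CUDA device and fp16,
# which Flash Attention requires as far as I understand)
model_fp16 = model.half().cuda()
x = input_tensor.half().cuda()
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = model_fp16(x)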
They also highlight the benefit of setting enable_nested_tensor=True, but I ran into a bug (already reported) and cannot set it to True.
Finally, they also advise against using an attention mask except for the causal mask. So, is it common not to use an attention mask for the PAD tokens?
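For context, this is roughly how I would normally build the padding mask (a sketch with made-up lengths; True marks the PAD positions that src_key_padding_mask is supposed to ignore):

# Sketch of the padding mask I would normally pass (hypothetical lengths)
batch_size, seq_len = 6, 1000
lengths = torch.randint(1, seq_len + 1, (batch_size,))          # real length of each sequence
pad_mask = torch.arange(seq_len)[None, :] >= lengths[:, None]   # (batch_size, seq_len), True = PAD
out = model.encoder(input_tensor, src_key_padding_mask=pad_mask)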
Thank you so much!!