Edited after further debugging and creating an even more minimal example.
Using nn.MultiheadAttention in eval() mode is substantially less efficient than using it in train() mode: in my example, the model uses 2.4GB of VRAM in train() and 10.6GB in eval(). The cause seems to be the fast path in nn.MultiheadAttention.forward(): when I disable it with torch.backends.mha.set_fastpath_enabled(False), the issue goes away.
Minimal example:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc


def test_mha_directly(mode: str):
    """Test MultiheadAttention directly."""
    device = "cuda"
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()
    # torch.backends.mha.set_fastpath_enabled(False)  # UNCOMMENTING THIS LINE SOLVES THE ISSUE

    # Create MHA module
    mha = nn.MultiheadAttention(
        embed_dim=256,
        num_heads=4,
        dropout=0,
        batch_first=True,
    ).to(device)
    if mode == "train":
        mha.train()
    else:
        mha.eval()

    batch_size = 32
    seq_len = 2000

    # Create inputs
    query = torch.randn(batch_size, seq_len, 256, device=device)
    key = query
    value = query
    key_padding_mask = torch.zeros(batch_size, seq_len, device=device, dtype=torch.bool)
    attn_mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool)

    print(f"Mode: {mode} (single MHA)")
    with torch.no_grad():
        output, _ = mha(query, key, value,
                        key_padding_mask=key_padding_mask,
                        attn_mask=attn_mask,
                        need_weights=False)
    torch.cuda.synchronize()
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print(f"Peak memory: {peak:.2f} MB")

    del mha, output
    torch.cuda.empty_cache()
    gc.collect()
    return peak


if __name__ == "__main__":
    print("Testing with hooks and direct MHA\n")
    print("\n" + "=" * 60)
    print("Single MultiheadAttention")
    print("=" * 60)
    train_mha = test_mha_directly("train")
    eval_mha = test_mha_directly("eval")
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Single MHA - Train: {train_mha:.2f} MB")
    print(f"Single MHA - Eval: {eval_mha:.2f} MB")
Running on an RTX 4090 with CUDA 12.7. Tried Python 3.12 + PyTorch 2.8.0 and Python 3.14 + PyTorch 2.9.0. Also tried running on a Google Colab instance. With all three setups, the issue persists.
It’s not clear to me why the fast_path would be so much less efficient than the “regular” path. It’s not just less memory-efficient; it’s also much slower.
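For a sense of scale, here is a back-of-envelope calculation of how much memory a single fully materialized tensor of shape (batch, heads, seq_len, seq_len) in float32 would take with the sizes from the example. This is only a guess at where the memory might be going: I have not verified which buffers either code path actually allocates, and the function name `score_tensor_mib` is just for illustration.

```python
def score_tensor_mib(batch_size: int, num_heads: int, seq_len: int,
                     bytes_per_elem: int = 4) -> float:
    """Size in MiB of one (batch, heads, seq_len, seq_len) tensor.

    bytes_per_elem=4 assumes float32; this is an estimate, not a measurement.
    """
    n_elems = batch_size * num_heads * seq_len * seq_len
    return n_elems * bytes_per_elem / 1024**2


# Sizes taken from the minimal example above.
one_buffer = score_tensor_mib(batch_size=32, num_heads=4, seq_len=2000)
print(f"One (B, H, S, S) float32 tensor: {one_buffer:.1f} MiB")
```

One such tensor comes out at a bit under 2GB, so the train() figure is in the range of roughly one buffer of this size and the eval() figure in the range of several. If that reading is right, the fast path would be materializing a few extra full-size score or mask tensors, but that is speculation on my part.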