nn.MultiheadAttention fast_path substantially less efficient

Edited after further debugging and creating an even more minimal example.

Using nn.MultiheadAttention in eval() mode is substantially less efficient than using it in train() mode: in my example, the model peaks at 2.4 GB of VRAM in train() and 10.6 GB in eval(). The culprit appears to be the fast_path in nn.MultiheadAttention.forward(): when I disable the fast path, the discrepancy disappears.
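
The workaround I'm currently using is the global flag below (the same call that is commented out in the repro), which forces nn.MultiheadAttention onto the regular path; a minimal sketch:

import torch

# Disable the nn.MultiheadAttention fast path globally before running inference.
# This is the same call that is commented out in the minimal example below.
torch.backends.mha.set_fastpath_enabled(False)

As far as I can tell, keeping the module in train() mode also avoids the fast path, and since dropout is 0 here the output should be unchanged, but the global flag seems like the cleaner option.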

Minimal example:

import torch
import torch.nn as nn
import torch.nn.functional as F
import gc

def test_mha_directly(mode: str):
    """Test MultiheadAttention directly."""
    device = "cuda"
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    # torch.backends.mha.set_fastpath_enabled(False)   # UNCOMMENTING THIS LINE SOLVES THE ISSUE
    
    # Create MHA module
    mha = nn.MultiheadAttention(
        embed_dim=256,
        num_heads=4,
        dropout=0,
        batch_first=True,
    ).to(device)
    
    if mode == "train":
        mha.train()
    else:
        mha.eval()
    
    batch_size = 32
    seq_len = 2000
    
    # Create inputs
    query = torch.randn(batch_size, seq_len, 256, device=device)
    key = query
    value = query
    key_padding_mask = torch.zeros(batch_size, seq_len, device=device, dtype=torch.bool)  # all False -> no keys masked
    attn_mask = torch.zeros(seq_len, seq_len, device=device, dtype=torch.bool)  # all False -> no positions masked
    
    print(f"Mode: {mode} (single MHA)")
    with torch.no_grad():
        output, _ = mha(
            query, key, value,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )
    
    torch.cuda.synchronize()
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print(f"Peak memory: {peak:.2f} MB")
    
    del mha, output
    torch.cuda.empty_cache()
    gc.collect()
    
    return peak

if __name__ == "__main__":
    print("Testing with hooks and direct MHA\n")

    print("\n" + "=" * 60)
    print("Single MultiheadAttention")
    print("=" * 60)
    train_mha = test_mha_directly("train")
    eval_mha = test_mha_directly("eval")
    
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Single MHA - Train:        {train_mha:.2f} MB")
    print(f"Single MHA - Eval:         {eval_mha:.2f} MB")

Running on an RTX 4090 with CUDA 12.7. I tried Python 3.12 + PyTorch 2.8.0, Python 3.14 + PyTorch 2.9.0, and a Google Colab instance; the issue persists with all three setups.

It’s not clear to me why the fast_path would be so much less efficient than the “regular” path: it is not only less memory efficient, it is also much slower.
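
For comparison, here is a rough baseline of what the attention itself should cost. As far as I understand, with need_weights=False the regular path ends up in F.scaled_dot_product_attention, so calling it directly with the same shapes (the variable names below are mine, not from the repro above) gives an idea of the expected peak memory:

import torch
import torch.nn.functional as F

device = "cuda"
batch_size, seq_len, embed_dim, num_heads = 32, 2000, 256, 4
head_dim = embed_dim // num_heads  # 64

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Same sizes as the repro, already split into heads: (batch, heads, seq, head_dim)
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
k, v = q, q

with torch.no_grad():
    out = F.scaled_dot_product_attention(q, k, v)

torch.cuda.synchronize()
print(f"SDPA peak memory: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

If this lands in the same ballpark as the train() number, that would point at the fast path's handling of the masks rather than the attention computation itself.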