nn.TransformerEncoder OOM with no_grad + eval + mask

nn.TransformerEncoder uses significantly more memory and OOMs in a no_grad + eval scenario. I understand that some memory/speed trade-offs are made for the fast paths in the codebase, but this seems excessive when it reaches the point where OOM occurs during evaluation and not during training with autograd running. Once the no_grad + eval + mask trifecta occurs, ~3 GB of usage balloons into an OOM on an RTX 3090 (the failing request is +8 GiB at the point of OOM). If you remove any one of these three conditions, the issue does not occur.
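I have not dug into where the allocation actually comes from, but for what it's worth, the +8 GiB figure lines up exactly with a float32 mask expanded per batch and per head for the repro below. This is just arithmetic, not something I have confirmed in the PyTorch source:

# Back-of-the-envelope only: size of a float32 mask expanded to (batch * heads, seq, seq)
# for the repro below. Purely illustrative -- not confirmed to be what is allocated.
B, H, S = 16, 8, 4096
print(B * H * S * S * 4 / 2**30)  # -> 8.0 (GiB), matching the failing +8GiB request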

torch version: 2.4.0+cu124
GPU: RTX 3090

# Minimal reproducible example
import torch
from torch import nn


def make_causal_mask(time: int, n_tokens: int, device: torch.device | None):
    """Make causal mask where multiple tokens are in each timestep"""
    mask = (
        torch.arange(time * n_tokens, device=device)
        .unsqueeze(0)
        .repeat(time * n_tokens, 1)
        // n_tokens
    )
    indices = torch.arange(time * n_tokens, device=device).unsqueeze(-1) // n_tokens
    final = mask > indices  # True is not allowed to attend
    return final


mod = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(128, 8, batch_first=True), 2
).cuda()
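# batch of 16 sequences, each 512 timesteps x 8 tokens = 4096 positions, d_model=128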
data = torch.randn(16, 4096, 128).cuda()
T = 512
N = 8
mask = make_causal_mask(T, N, device=data.device)

res = mod(data, mask=mask)
torch.cuda.synchronize()
print(f"training: {torch.cuda.memory_allocated() // 1e9} GB")

del res
torch.cuda.empty_cache()

with torch.no_grad():
    mod = mod.eval()
    res = mod(data, mask=mask)
    print(f"eval+no_grad+mask: {torch.cuda.memory_allocated() // 1e9} GB")