nn.TransformerEncoder uses significantly more memory and OOMs in a no_grad + eval scenario. I understand memory/speed trade-offs are made for the fast paths in the codebase, but I think it is excessive when an OOM occurs during evaluation that does not occur during training with autograd running. Once the no_grad + mask + eval trifecta is in place, ~3 GB of usage balloons until it OOMs an RTX 3090 (the failing request at OOM is +8 GiB). Remove any one of these three conditions and the issue does not occur.
torch version: 2.4.0+cu124
GPU: RTX 3090
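For what it's worth, the failing +8 GiB request matches the size of the boolean mask materialized as a dense float32 tensor broadcast over batch and heads. That is only my guess at what gets allocated, but the arithmetic (using the shapes from the repro below) lines up:

# Back-of-the-envelope check (my assumption about the allocation, not confirmed)
batch, heads, seq = 16, 8, 4096
print(batch * heads * seq * seq * 4 / 2**30)  # -> 8.0 (GiB)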
# Minimal reproducible example
import torch
from torch import nn
def make_causal_mask(time: int, n_tokens: int, device: torch.device | None):
"""Make causal mask where multiple tokens are in each timestep"""
mask = (
torch.arange(time * n_tokens, device=device)
.unsqueeze(0)
.repeat(time * n_tokens, 1)
// n_tokens
)
indices = torch.arange(time * n_tokens, device=device).unsqueeze(-1) // n_tokens
final = mask > indices # True is not allowed to attend
return final
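# Illustration (mine, not part of the original repro): make_causal_mask(2, 2, device=None)
# gives a 4x4 block-causal pattern where the two tokens of a timestep may attend to each
# other and to every earlier timestep, but not to later ones:
# tensor([[False, False,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False, False],
#         [False, False, False, False]])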
mod = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(128, 8, batch_first=True), 2
).cuda()
data = torch.randn(16, 4096, 128).cuda()
T = 512
N = 8
mask = make_causal_mask(T, N, device=data.device)
res = mod(data, mask=mask)
torch.cuda.synchronize()
print(f"training: {torch.cuda.memory_allocated() // 1e9} GB")
del res
torch.cuda.empty_cache()
with torch.no_grad():
    mod = mod.eval()
    res = mod(data, mask=mask)
    print(f"eval+no_grad+mask: {torch.cuda.memory_allocated() // 1e9} GB")