Mixed precision increases memory in meta-learning?

(FYI, unrelated to memory usage: you don’t need to set a manual SCALER value. torch.cuda.amp.GradScaler chooses the scale factor automatically and dynamically. You probably know that, but you may not know that it can be used in a double-backward setting; see the gradient penalty example. Or maybe you knew that too and avoided it in your example for simplicity.)
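Here is a minimal sketch of how the inner/outer loop from the script below might look with GradScaler in place of the fixed SCALER, adapted from the gradient penalty pattern. The inner autograd.grad call receives scaler.scale(loss) and the resulting grads are divided by scaler.get_scale() (plain division, since these grads aren't owned by an optimizer), while the outer backward goes through scaler.scale / step / update. Assumptions that are not in the original: small hypothetical sizes, a fresh x_train per iteration, an SGD optimizer for regularizer, and a CPU fallback with AMP disabled so it also runs without a GPU. Newer releases spell the scaler torch.amp.GradScaler("cuda").

```python
import torch

# Assumption: use AMP only when CUDA is available; on CPU everything runs in FP32.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

torch.manual_seed(0)
N, D = 256, 16  # hypothetical small sizes for illustration
true_w = torch.randn(D, device=device)
x_val = torch.randn(N, D, device=device)
y_val = x_val @ true_w

weights = torch.randn(D, device=device, requires_grad=True)                # inner param
regularizer = torch.full((D,), 5e-4, device=device, requires_grad=True)    # outer param
outer_opt = torch.optim.SGD([regularizer], lr=1e-3)  # hypothetical outer optimizer
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # picks the scale dynamically

## Inner loop: learn weights
for _ in range(10):
    x_train = torch.randn(N, D, device=device)
    y_train = x_train @ true_w
    with torch.autocast(device_type=device, enabled=use_amp):
        train_loss = torch.mean((y_train - torch.matmul(x_train, weights)) ** 2)
    # Exit autocast before autograd.grad, as the gradient penalty example does.
    scaled_grads = torch.autograd.grad(scaler.scale(train_loss), weights, create_graph=True)[0]
    inner_grads = scaled_grads / scaler.get_scale()  # unscale by ordinary division
    weights = weights - 0.1 * inner_grads - regularizer * weights

## Outer step: learn regularizer
with torch.autocast(device_type=device, enabled=use_amp):
    val_loss = torch.mean((y_val - torch.matmul(x_val, weights)) ** 2)

outer_opt.zero_grad()
scaler.scale(val_loss).backward()
scaler.step(outer_opt)  # unscales regularizer.grad; skips the step if it holds inf/NaN
scaler.update()         # adjusts the scale for the next outer step
```

Since the inner grads feed into weights and therefore into val_loss, an inf/NaN from an overflowing scaled inner pass should surface in regularizer.grad, so scaler.step should skip that outer step and lower the scale.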

I have good news and bad news.

Bad news: I think all parts under the hood are working as intended. I think the “leak” you observe happens because x_train is big and, with amp enabled, is cast every iteration. On entry to each matmul, x_train is cast to a new FP16 tensor, and that cast copy is stashed for backward (by autograd, not amp). By design of the algorithm, the autograd history of ops involving weights is retained across the 10 inner iterations, which means all 10 FP16 copies of x_train are retained as well.
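A small sketch that makes the “cast copy is stashed for backward” behavior visible, using torch.autograd.graph.saved_tensors_hooks to record what autograd saves for a single matmul. Without autocast, the saved tensor is x itself (same address); under autocast, it is a new lower-precision tensor at a different address. Assumptions: tiny hypothetical sizes, a (D, 1) weight column instead of the 1-D weights in the script, and bfloat16 when running on CPU (CPU autocast's lower-precision dtype) rather than FP16.

```python
import torch

def saved_tensor_info(use_autocast: bool):
    """Return x's address and (dtype, address) of each tensor autograd saves for one matmul."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Assumption: FP16 on CUDA, bfloat16 on CPU.
    low = torch.float16 if device == "cuda" else torch.bfloat16
    x = torch.randn(64, 8, device=device)                     # data: no grad
    w = torch.randn(8, 1, device=device, requires_grad=True)  # param
    saved = []

    def pack(t):
        saved.append((t.dtype, t.data_ptr()))  # record what autograd stashes
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        with torch.autocast(device_type=device, dtype=low, enabled=use_autocast):
            torch.matmul(x, w)
    return x.data_ptr(), saved

ptr, plain = saved_tensor_info(False)
print(any(p == ptr for _, p in plain))  # x itself is saved
ptr, amp = saved_tensor_info(True)
print(any(p == ptr for _, p in amp))    # a separate low-precision copy is saved instead
```

In the meta-learning loop, each iteration's graph stays reachable through weights, so every one of these per-iteration copies stays alive until the outer backward.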

Within any (outermost) `with autocast` region, the backend caches some casts to streamline tensor reuse, but only tensors that require grad are eligible (the idea was to cache casts of model params). For the exact code you posted, I could relax the “requires grad” criterion; then, if you ran all 10 iterations under a single `with autocast` (as opposed to entering and exiting it every iteration), x_train would be cast only once and hit in the cache on later iterations. But in a real training script, x_train will (I assume) be a new tensor every iteration. Telling autocast it may cache the cast of each x_train won’t help: the casts will be distinct and stashed separately every iteration regardless.
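For the static toy script only, a similar effect can be had without touching autocast's cache rules: cast x_train to low precision once yourself. Autocast leaves inputs that are already in its target dtype alone, so every iteration stashes the same storage instead of a fresh copy. This is a sketch under stated assumptions (hypothetical helper name, tiny sizes, a (D, 1) weight column, bfloat16 on CPU and FP16 on CUDA), and it only helps when the data really is static; with a new x_train every iteration there is nothing to reuse.

```python
import torch

def saved_ptrs_with_precast(n_iters: int = 3):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Assumption: FP16 on CUDA, bfloat16 on CPU (CPU autocast's lower-precision dtype).
    low = torch.float16 if device == "cuda" else torch.bfloat16
    # Hypothetical workaround: cast the static data once, before the loop.
    x_low = torch.randn(64, 8, device=device).to(low)
    w = torch.randn(8, 1, device=device, requires_grad=True)
    ptrs = []

    def pack(t):
        ptrs.append(t.data_ptr())  # remember which storage autograd stashed
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        for _ in range(n_iters):
            # Enter autocast anew each iteration, as the script below does.
            with torch.autocast(device_type=device, dtype=low):
                torch.matmul(x_low, w)  # w gets cast; x_low is already `low`, so it does not
    return x_low.data_ptr(), ptrs

ptr, ptrs = saved_ptrs_with_precast()
print(ptrs.count(ptr))  # how many iterations stashed x_low itself rather than a copy
```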

Good news: the “leakage” difference you see between amp and non-amp is purely because x_train is static across iterations, so without amp, matmul uses and stashes the same x_train for backward every iteration. If I change the script slightly to be more realistic, generating a new random x_train on each of the 10 iterations (below), AMP=False shows the same leaky behavior. In fact, it’s even worse, ending up at 2.87 GB on my machine.
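A back-of-envelope check of those numbers, counting only the big N×D matrices (x_train, its FP16 copies, and x_val) and ignoring the N-length vectors and other small tensors. The script divides by 1024**3, so these are GiB. With AMP=False and a fresh x_train per iteration, the 10 stashed x_train tensors, the one created after the last iteration, and x_val give 12 FP32 matrices, which is in line with the 2.87 GB reported above.

```python
GiB = 1024 ** 3
N, D = 80000, 800           # sizes from the script
fp32 = N * D * 4 / GiB      # one FP32 N x D matrix
fp16 = N * D * 2 / GiB      # one FP16 autocast copy

# Static x_train (the original script): non-AMP stashes the same tensor every time,
# AMP stashes 10 distinct FP16 copies on top of x_train and x_val.
no_amp_static = 2 * fp32
amp_static = 2 * fp32 + 10 * fp16

# Fresh x_train per iteration: non-AMP keeps all 10 used x_train alive via the graph,
# plus the unused one created after the last iteration, plus x_val.
no_amp_fresh = (10 + 1) * fp32 + fp32
# AMP keeps only the FP16 copies; the FP32 originals are freed once reassigned.
amp_fresh = 10 * fp16 + 2 * fp32

print(f"{no_amp_static:.2f} {amp_static:.2f} {no_amp_fresh:.2f} {amp_fresh:.2f}")
```

This prints `0.48 1.67 2.86 1.67`, which matches the picture above: AMP only looked worse because the static x_train let non-AMP get away with a single stashed copy.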

So I’m calling this not a bug. (Hopefully, imitating the gradient penalty pattern with GradScaler helps mitigate NaNs. Note that the pattern exits the autocast context before the call to autograd.grad, although that doesn’t make a difference in your case.)

```python
# https://discuss.pytorch.org/t/mixed-precision-increases-memory-in-meta-learning/115608
import torch
# torch.backends.cudnn.benchmark=True

DEVICE='cuda'
N, D = 80000, 800  # may need multiples of 8 (for Tensor Cores)
# AMP=True
AMP=False

def get_data(N, D):
    x_train = torch.randn((N, D), device=DEVICE)
    x_val = torch.randn((N, D), device=DEVICE)
    true_weights = torch.randn(D, device=DEVICE)
    y_train = torch.matmul(x_train, true_weights)+ torch.randn(N, device=DEVICE)*0.05
    y_val = torch.matmul(x_val, true_weights)+ torch.randn(N, device=DEVICE)*0.05
    return x_train, y_train, x_val, y_val

torch.manual_seed(0)
torch.cuda.manual_seed(0)
# with torch.cuda.amp.autocast(enabled=AMP):
x_train, y_train, x_val, y_val = get_data(N, D)
weights = torch.randn(D, requires_grad=True, device=DEVICE) # inner param to learn
regularizer = torch.full((D,), 5e-4, requires_grad=True, device=DEVICE) # outer param to learn
SCALER = 2**8 if AMP else 1

## Inner loop: learn weights
for i in range(10):
    with torch.cuda.amp.autocast(enabled=AMP):
        y_pred_train = torch.matmul(x_train, weights)
        train_loss = torch.mean((y_train-y_pred_train)**2)
        # with torch.cuda.amp.autocast(enabled=False):
        inner_grads = torch.autograd.grad(SCALER*train_loss, weights, create_graph=True)[0]
        inner_grads = (1/SCALER)*inner_grads
        weights = weights - 0.1*inner_grads - regularizer*weights
    x_train = torch.randn((N, D), device=DEVICE)  # new x_train every iteration (the "more realistic" change)
    # print(f'train loss {train_loss:.3g} -- inner grads min {torch.min(inner_grads):.3f} max {torch.max(inner_grads):.3f}')
    print(f'memory allocated {float(torch.cuda.memory_allocated()) / (1024**3):.3g} GB')

## Outer step: learn regularizer
with torch.cuda.amp.autocast(enabled=AMP):
    y_pred_val = torch.matmul(x_val, weights)
    val_loss = SCALER * torch.mean((y_val - y_pred_val) ** 2)

val_loss.backward()
# print(f'---> outer grads {regularizer.grad*(1/SCALER)}')
```
