Mixed precision increases memory in meta-learning?

In meta-learning you want to differentiate through the (inner) gradient updates themselves, for example to get the (outer) gradient of the validation loss with respect to some hyperparameter.
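Here is a minimal sketch of what "differentiating through the inner updates" means. It uses a scalar toy problem with hypothetical values, runs in float64 on CPU, and uses the same update form as the snippet below, `w = w - lr*g - r*w`. It checks the autograd hypergradient with respect to `r` against a central finite difference. The function names and all constants are illustrative assumptions.

```python
import torch

def unrolled_val_loss(r, steps=2, w0=1.0, a=3.0, b=0.5, lr=0.1):
    """Unroll `steps` inner updates of w on the toy train loss (w - a)**2
    and return the toy validation loss (w - b)**2 as a function of r."""
    w = torch.tensor(w0, dtype=torch.float64, requires_grad=True)
    for _ in range(steps):
        train_loss = (w - a) ** 2
        # create_graph=True records the gradient computation itself, so the
        # next inner gradient stays a differentiable function of r
        (g,) = torch.autograd.grad(train_loss, w, create_graph=True)
        w = w - lr * g - r * w  # same form as the inner update in the snippet
    return (w - b) ** 2

def hypergradient(r_value):
    """d(val loss)/dr via autograd through the unrolled inner loop."""
    r = torch.tensor(r_value, dtype=torch.float64, requires_grad=True)
    (dr,) = torch.autograd.grad(unrolled_val_loss(r), r)
    return dr.item()

def finite_difference(r_value, eps=1e-6):
    """Central finite-difference estimate of the same derivative."""
    r_plus = torch.tensor(r_value + eps, dtype=torch.float64)
    r_minus = torch.tensor(r_value - eps, dtype=torch.float64)
    diff = unrolled_val_loss(r_plus) - unrolled_val_loss(r_minus)
    return (diff / (2 * eps)).item()

print(hypergradient(0.01), finite_difference(0.01))
```

If `create_graph=True` were dropped, `g` would come back as a constant with no history. The hypergradient would then miss the path through the second inner gradient and disagree with the finite difference.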

I had issues with my outer gradients being NaN in mixed precision (regardless of the loss scaler value), so I made a toy example. I can’t reproduce the NaN outer gradient with it, but it exposes another issue, namely memory consumption being larger in AMP mode:

import torch
# torch.backends.cudnn.benchmark=True

DEVICE='cuda'
N, D = 80000, 800 # may need multiple of 8
AMP=True

def get_data(N, D):
    x_train = torch.randn((N, D), device=DEVICE)
    x_val = torch.randn((N, D), device=DEVICE)
    true_weights = torch.randn(D, device=DEVICE)
    y_train = torch.matmul(x_train, true_weights)+ torch.randn(N, device=DEVICE)*0.05
    y_val = torch.matmul(x_val, true_weights)+ torch.randn(N, device=DEVICE)*0.05
    return x_train, y_train, x_val, y_val

torch.manual_seed(0)
torch.cuda.manual_seed(0)
# with torch.cuda.amp.autocast(enabled=AMP):
x_train, y_train, x_val, y_val = get_data(N, D)
weights = torch.randn(D, requires_grad=True, device=DEVICE) # inner param to learn
regularizer = torch.full((D,), 5e-4, requires_grad=True, device=DEVICE) # outer param to learn
SCALER = 2**8 if AMP else 1

## Inner loop: learn weights
for i in range(10):
    with torch.cuda.amp.autocast(enabled=AMP):
        y_pred_train = torch.matmul(x_train, weights)
        train_loss = torch.mean((y_train-y_pred_train)**2)
        # with torch.cuda.amp.autocast(enabled=False):
        inner_grads = torch.autograd.grad(SCALER*train_loss, weights, create_graph=True)[0]
        inner_grads = (1/SCALER)*inner_grads
        weights = weights - 0.1*inner_grads - regularizer*weights
    # print(f'train loss {train_loss:.3g} -- inner grads min {torch.min(inner_grads):.3f} max {torch.max(inner_grads):.3f}')
    print(f'memory allocated {float(torch.cuda.memory_allocated()) / (1024**3):.3g} GB')

## Outer step: learn regularizer
with torch.cuda.amp.autocast(enabled=AMP):
    y_pred_val = torch.matmul(x_val, weights)
    val_loss = SCALER * torch.mean((y_val - y_pred_val) ** 2)

val_loss.backward()
# print(f'---> outer grads {regularizer.grad*(1/SCALER)}')

When I run this with AMP=False I get:

memory allocated 0.478 GB
memory allocated 0.479 GB
memory allocated 0.48 GB
memory allocated 0.48 GB
memory allocated 0.481 GB
memory allocated 0.481 GB
memory allocated 0.482 GB
memory allocated 0.483 GB
memory allocated 0.483 GB
memory allocated 0.484 GB

but when I run it with AMP=True I get:

memory allocated 0.597 GB
memory allocated 0.717 GB
memory allocated 0.837 GB
memory allocated 0.957 GB
memory allocated 1.08 GB
memory allocated 1.2 GB
memory allocated 1.32 GB
memory allocated 1.44 GB
memory allocated 1.56 GB
memory allocated 1.68 GB

What is happening here?

System:
Windows 10
PyTorch 1.7.1
CUDA 11.0
cuDNN 8.0
Python 3.8
RTX 3070 (laptop version)

Same behavior observed on an RTX 2080 (also CUDA 11.0). Should I report this as a bug on GitHub, @ptrblck_de?

FYI, this behavior may be specific to torch.matmul: replacing the forward prop with a CNN leads to less memory allocated when AMP=True, which is what we expect. Unfortunately, the error between the AMP=True and AMP=False gradients is much larger in the case of CNNs…

Could you update to the latest stable release or the nightly and rerun the test?
We hit a caching issue some time ago for linear layers (and fixed it). Based on the output you are seeing, I don’t think the problem is that AMP uses more memory in particular, but that the memory usage is clearly increasing/leaking.

I get the same behavior after upgrading to PyTorch 1.8.0. I’ve also observed AMP having a higher memory cost in non-meta-learning settings, but I think that was due to not running the entire forward pass under autocast. Could that be the case here somehow?

It’s disappointing because meta-learning is expensive in memory and AMP would help a lot of folks in the field. That said, I suspect that even once this memory issue is sorted, the (outer) gradient calculation will still be inaccurate in this case, because the graph is “chaotic” and gradient degradation is often an issue even in FP32.

Should I report this as a bug, @ptrblck?

@mcarilli could you check the initial code snippet regarding the memory usage?

I don’t understand your question, sorry. I already checked with the latest stable PyTorch release (1.8.0). The snippet is self-contained and you should be able to run it as is. Do you not reproduce the memory leak?

Sorry, the question was for a colleague (the original author of the mixed-precision training utility) to take a look at this issue, if he has time. :wink:

Sorry I read your message too fast! :smile:

(FYI, unrelated to memory usage, you don’t need to set a manual SCALER value. torch.cuda.amp.GradScaler automatically and dynamically chooses the scale factor. You probably know that, but you may not know it can be used in a double-backward setting. See the gradient penalty example. Or maybe you knew that too and avoided it in your example for simplicity.)
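The following is a minimal sketch of that pattern applied to the inner loop. It assumes the same linear-regression toy as the snippet, with small hypothetical sizes so it also runs on CPU, where AMP and the scaler are simply disabled. `scaler.scale()` replaces the fixed `SCALER`. Because no optimizer owns the inner gradients, they are divided by `scaler.get_scale()`. As in the gradient penalty example, autocast is exited before `autograd.grad`. `inner_step` and `demo` are illustrative names, not part of any API.

```python
import torch

USE_AMP = torch.cuda.is_available()
DEVICE = "cuda" if USE_AMP else "cpu"

def inner_step(weights, regularizer, x, y, scaler, lr=0.1):
    """One differentiable inner update using the scale chosen by GradScaler."""
    with torch.cuda.amp.autocast(enabled=USE_AMP):
        train_loss = torch.mean((y - torch.matmul(x, weights)) ** 2)
    # autocast is exited before autograd.grad, as in the gradient penalty example
    (scaled_grads,) = torch.autograd.grad(
        scaler.scale(train_loss), weights, create_graph=True)
    # no optimizer owns these gradients, so unscale them by plain division
    inner_grads = scaled_grads / scaler.get_scale()
    return weights - lr * inner_grads - regularizer * weights

def demo(n=64, d=8, steps=3):
    """Unroll a few inner steps, then return the unscaled outer gradient."""
    torch.manual_seed(0)
    scaler = torch.cuda.amp.GradScaler(init_scale=2.0 ** 8, enabled=USE_AMP)
    x = torch.randn(n, d, device=DEVICE)
    y = torch.randn(n, device=DEVICE)
    weights = torch.randn(d, device=DEVICE, requires_grad=True)
    regularizer = torch.full((d,), 5e-4, device=DEVICE, requires_grad=True)
    for _ in range(steps):
        weights = inner_step(weights, regularizer, x, y, scaler)
    with torch.cuda.amp.autocast(enabled=USE_AMP):
        val_loss = torch.mean((y - torch.matmul(x, weights)) ** 2)
    scaler.scale(val_loss).backward()
    return regularizer.grad / scaler.get_scale()

print(demo())
```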

I have good news and bad news.

Bad news: I think all parts under the hood are working as intended. I think the “leak” you observe happens because x_train is big and (with amp enabled) cast every iteration. x_train is cast to a new FP16 tensor on entrance to each matmul, and the cast copy is stashed for backward (by autograd, not amp). The autograd history of ops involving weights is retained across the 10 inner iterations, by design of the algorithm, which means all 10 FP16 copies of x_train are also retained.
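Here is a back-of-envelope check of that explanation. It assumes only the large tensor sizes matter, ignoring allocator overhead and small intermediates. Note that the snippet divides by 1024**3, so its "GB" is really GiB. Each FP16 copy of the (80000, 800) x_train comes to about 0.119 GiB, which matches the roughly 0.12 GB step per iteration in the AMP=True log. The FP32 x_train and x_val together account for the 0.478 GB baseline. The names below are illustrative.

```python
GIB = 1024 ** 3  # the snippet divides memory_allocated() by 1024**3

def tensor_gib(num_elements, bytes_per_element):
    """Size of a dense tensor in GiB."""
    return num_elements * bytes_per_element / GIB

N, D, INNER_STEPS = 80000, 800, 10
fp32_copy = tensor_gib(N * D, 4)  # x_train or x_val in FP32
fp16_copy = tensor_gib(N * D, 2)  # one autocast copy of x_train

baseline = 2 * fp32_copy  # x_train and x_val
# one more FP16 copy of x_train is kept alive by the graph per inner iteration
predicted_amp = [baseline + (i + 1) * fp16_copy for i in range(INNER_STEPS)]

print(f"baseline {baseline:.3g} GiB, per-iteration growth {fp16_copy:.3g} GiB")
print([f"{m:.3g}" for m in predicted_amp])
```

The estimate ends at 1.67 GiB after 10 iterations, close to the logged 1.68 GB. The remaining gap, about 0.001 per iteration, is of the same order as the slow growth in the AMP=False log. It presumably comes from small per-iteration tensors.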

Within any (outermost) invocation of with autocast, the backend caches some casts to streamline tensor reuse. But only tensors that require grad are eligible (the idea was to cache casts for model params). For the exact code you posted, I could relax the “requires grad” criterion. Then, if you ran all 10 iterations under a single with autocast (as opposed to entering and exiting it every iteration), x_train would only be cast once and would hit in the cache for later iterations. But in a real training script, x_train (I assume) will be a new tensor every iteration. Telling autocast it may cache the casts of each x_train won’t help: they’ll be distinct and separately stashed every iteration regardless.
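You can sketch the difference between casting a static x_train on every call and casting it once on CPU with saved-tensor hooks. These hooks (`torch.autograd.graph.saved_tensors_hooks`, PyTorch 1.10 or later) see every tensor autograd stashes for backward. To keep the example runnable without a GPU, an explicit float32 to float64 cast stands in for autocast's FP32 to FP16 cast. This is an assumption; the point is only that each cast allocates a new tensor that autograd then keeps alive. `count_stashed_copies` is a hypothetical helper that counts distinct buffers by `data_ptr()`, which is valid here because all the copies are alive at the same time.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

N, D, STEPS = 256, 8, 5  # small hypothetical sizes so this runs quickly on CPU

def count_stashed_copies(precast):
    """Run an unrolled inner loop and count distinct (N, D) buffers saved for backward."""
    torch.manual_seed(0)
    x = torch.randn(N, D)  # stored in float32, like x_train
    y = torch.randn(N, dtype=torch.float64)
    w = torch.randn(D, dtype=torch.float64, requires_grad=True)
    # float64 plays the role of autocast's FP16 compute dtype (assumption for CPU)
    x_once = x.to(torch.float64) if precast else None

    seen = set()

    def pack(t):
        if t.numel() == N * D:
            seen.add(t.data_ptr())
        return t

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        for _ in range(STEPS):
            # casting inside the loop allocates a new copy each iteration,
            # which autograd then keeps alive through the unrolled graph
            xc = x_once if precast else x.to(torch.float64)
            loss = torch.mean((y - torch.matmul(xc, w)) ** 2)
            (g,) = torch.autograd.grad(loss, w, create_graph=True)
            w = w - 0.1 * g
    return len(seen)

print(count_stashed_copies(precast=False), count_stashed_copies(precast=True))
```

With the cast inside the loop, every iteration stashes a different copy. Casting once up front leaves a single stashed copy, which is roughly what a cast cache would achieve for a static x_train. As noted above, this does not help when x_train changes every iteration.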

Good news: The “leakage” difference you see between amp and non-amp is purely because x_train is static across iterations, so with non-amp, matmul uses and stashes the same x_train for backward each iteration. If I change the script slightly to be more realistic, generating a new random x_train in each of the 10 iterations (below), AMP=False shows the same leaky behavior. In fact it’s even worse, ending up at 2.87 GB on my machine.
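The 2.87 GB end point can be checked with a sizes-only estimate, ignoring allocator overhead and small intermediates. With a fresh FP32 x_train per iteration, the graph keeps all 10 of them alive, on top of x_val and the x_train created after the last iteration. The estimate can also be extended to AMP=True in the modified script, as an inference from the explanation above rather than a measurement from this thread. In that case each iteration retains only an FP16 cast, and the previous FP32 x_train is freed once it is rebound.

```python
GIB = 1024 ** 3
N, D, INNER_STEPS = 80000, 800, 10

fp32_copy = N * D * 4 / GIB  # one FP32 (N, D) tensor
fp16_copy = N * D * 2 / GIB  # one FP16 cast of it

# AMP=False: x_val, the 10 x_train tensors saved by the graph, and the final fresh x_train
fp32_end = (1 + INNER_STEPS + 1) * fp32_copy
# AMP=True (estimate): x_val and the current x_train in FP32, plus one saved FP16 cast per iteration
amp_end = 2 * fp32_copy + INNER_STEPS * fp16_copy

print(f"AMP=False ~{fp32_end:.3g} GiB, AMP=True ~{amp_end:.3g} GiB")
```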

So I’m calling this not a bug. (Hopefully imitating the gradient penalty pattern with GradScaler helps mitigate the NaNs. Note that it exits the autocast context before the call to autograd.grad, although that doesn’t make a difference in your case.)

# https://discuss.pytorch.org/t/mixed-precision-increases-memory-in-meta-learning/115608
import torch
# torch.backends.cudnn.benchmark=True

DEVICE='cuda'
N, D = 80000, 800 # may need multiple of 8
# AMP=True
AMP=False

def get_data(N, D):
    x_train = torch.randn((N, D), device=DEVICE)
    x_val = torch.randn((N, D), device=DEVICE)
    true_weights = torch.randn(D, device=DEVICE)
    y_train = torch.matmul(x_train, true_weights)+ torch.randn(N, device=DEVICE)*0.05
    y_val = torch.matmul(x_val, true_weights)+ torch.randn(N, device=DEVICE)*0.05
    return x_train, y_train, x_val, y_val

torch.manual_seed(0)
torch.cuda.manual_seed(0)
# with torch.cuda.amp.autocast(enabled=AMP):
x_train, y_train, x_val, y_val = get_data(N, D)
weights = torch.randn(D, requires_grad=True, device=DEVICE) # inner param to learn
regularizer = torch.full((D,), 5e-4, requires_grad=True, device=DEVICE) # outer param to learn
SCALER = 2**8 if AMP else 1

## Inner loop: learn weights
for i in range(10):
    with torch.cuda.amp.autocast(enabled=AMP):
        y_pred_train = torch.matmul(x_train, weights)
        train_loss = torch.mean((y_train-y_pred_train)**2)
        # with torch.cuda.amp.autocast(enabled=False):
        inner_grads = torch.autograd.grad(SCALER*train_loss, weights, create_graph=True)[0]
        inner_grads = (1/SCALER)*inner_grads
        weights = weights - 0.1*inner_grads - regularizer*weights
    x_train = torch.randn((N, D), device=DEVICE)
    # print(f'train loss {train_loss:.3g} -- inner grads min {torch.min(inner_grads):.3f} max {torch.max(inner_grads):.3f}')
    print(f'memory allocated {float(torch.cuda.memory_allocated()) / (1024**3):.3g} GB')

## Outer step: learn regularizer
with torch.cuda.amp.autocast(enabled=AMP):
    y_pred_val = torch.matmul(x_val, weights)
    val_loss = SCALER * torch.mean((y_val - y_pred_val) ** 2)

val_loss.backward()
# print(f'---> outer grads {regularizer.grad*(1/SCALER)}')


Thanks a lot for looking into this. I can confirm that the realistic setting with different xs leads to more memory use for AMP=False as expected. I am considering this issue resolved :slight_smile:

PS: About manual scaler values. The issue is that the scaler.step() method requires an optimizer as input. In my case, I cannot use PyTorch optimizers because I need the weights at step t to be a function of the weights at steps t-1, t-2, ..., 0, in order to backprop through the whole inner loop. Furthermore, AFAIK GradScaler simply decreases the scale value when NaNs occur, but I grid-searched a large range of scale values and always got NaNs in my code, so I don’t think GradScaler would help in my case. Unless NaNs can occur for reasons I haven’t considered, I’m guessing some computational graphs are just too deep/chaotic for AMP.
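On the `scaler.step()` point: the optimizer only has to own the parameters that are actually stepped. In the snippet those are the outer parameters (`regularizer`, a leaf tensor). The sketch below assumes the same toy problem with small hypothetical sizes. It keeps the unrolled inner updates manual and passes only the outer update to a standard optimizer through `GradScaler`. If the unscaled outer gradients contain inf or NaN, `scaler.step()` skips the update and `scaler.update()` lowers the scale. This sketch cannot show whether that helps with NaNs that appear at every fixed scale. `meta_step` is an illustrative name.

```python
import torch

USE_AMP = torch.cuda.is_available()
DEVICE = "cuda" if USE_AMP else "cpu"

def meta_step(x_tr, y_tr, x_val, y_val, w0, regularizer, outer_opt, scaler,
              inner_steps=5, lr=0.1):
    """Unroll the inner loop manually, then step only the outer parameters."""
    weights = w0
    for _ in range(inner_steps):
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            train_loss = torch.mean((y_tr - torch.matmul(x_tr, weights)) ** 2)
        (scaled,) = torch.autograd.grad(
            scaler.scale(train_loss), weights, create_graph=True)
        # weights at step t stay a differentiable function of steps t-1, ..., 0
        weights = weights - lr * scaled / scaler.get_scale() - regularizer * weights
    with torch.cuda.amp.autocast(enabled=USE_AMP):
        val_loss = torch.mean((y_val - torch.matmul(x_val, weights)) ** 2)
    outer_opt.zero_grad()
    scaler.scale(val_loss).backward()  # w0.grad also accumulates; it is unused here
    scaler.step(outer_opt)  # unscales regularizer.grad; skips the step on inf/NaN
    scaler.update()         # adjusts the scale for the next meta step
    return val_loss.item()

torch.manual_seed(0)
n, d = 64, 8
x_tr, x_val = torch.randn(n, d, device=DEVICE), torch.randn(n, d, device=DEVICE)
y_tr, y_val = torch.randn(n, device=DEVICE), torch.randn(n, device=DEVICE)
w0 = torch.randn(d, device=DEVICE, requires_grad=True)
regularizer = torch.full((d,), 5e-4, device=DEVICE, requires_grad=True)
outer_opt = torch.optim.SGD([regularizer], lr=1e-3)
scaler = torch.cuda.amp.GradScaler(init_scale=2.0 ** 8, enabled=USE_AMP)

reg_before = regularizer.detach().clone()
losses = [meta_step(x_tr, y_tr, x_val, y_val, w0, regularizer, outer_opt, scaler)
          for _ in range(3)]
print(losses)
```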