Increased memory usage with AMP

Hi. I’m benchmarking automatic mixed precision (AMP) against the default float32 mode. I’m getting a speed-up, but memory usage is the same, if not higher. I’m running PyTorch’s AMP tutorial with minimal changes:

import torch, time, gc
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Switch between the two modes by toggling this flag
use_amp = True
# use_amp = False

start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # resets both the peak allocated and peak reserved counters
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print(f'Memory allocated {torch.cuda.memory_allocated() // (1024**2)} MB')
    print(f'Max memory allocated {torch.cuda.max_memory_allocated() // (1024**2)} MB')
    print(f'Memory reserved {torch.cuda.memory_reserved() // (1024**2)} MB')
    print(f'Max memory reserved {torch.cuda.max_memory_reserved() // (1024**2)} MB')

def make_model(in_size, out_size, num_layers):
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*layers).cuda()

epochs = 1
num_batches = 50
batch_size = 512 # Try, for example, 128, 256, 513.

in_size = 4096
out_size = 4096
num_layers = 16

# in_size = 8192
# out_size = 8192
# num_layers = 32

data = [torch.randn(batch_size, in_size, device="cuda") for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, device="cuda") for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad() # set_to_none=True here can modestly improve performance
message = "Mixed precision:" if use_amp else "Default precision:"
end_timer_and_print(message)

Outputs:

Default precision:
Total execution time = 3.553 sec
Memory allocated 2856 MB
Max memory allocated 3176 MB
Memory reserved 3454 MB
Max memory reserved 3454 MB
# nvidia-smi shows 4900 MB

Mixed precision:
Total execution time = 1.652 sec
Memory allocated 2852 MB
Max memory allocated 3520 MB
Memory reserved 3646 MB
Max memory reserved 3646 MB
# nvidia-smi shows 5092 MB
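
If it helps pinpoint where the extra memory sits, I can also dump the caching allocator’s own breakdown at the end of a run. This is just the stock torch.cuda.memory_summary() call added to end_timer_and_print (the numbers above were produced without it):

    # Per-segment allocator statistics (allocated vs. reserved, active blocks, etc.)
    print(torch.cuda.memory_summary(device=0, abbreviated=True))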

When I try to saturate the GPU (Quadro RTX 6000 with 24 GB of memory) by switching to the larger hyperparameters from the commented-out block above, default precision works, but AMP goes out of memory:

in_size = 8192
out_size = 8192
num_layers = 32

Outputs:

Default precision:
Total execution time = 29.503 sec
Memory allocated 18002 MB
Max memory allocated 19282 MB
Memory reserved 19284 MB
Max memory reserved 19284 MB
# nvidia-smi shows 20730 MB

# Mixed precision goes out of memory:
RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 23.65 GiB total capacity; 22.08 GiB already allocated; 161.44 MiB free; 22.08 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
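
I haven’t tried the max_split_size_mb hint from the error message yet. As far as I understand, it would be set before the first CUDA allocation, something like this (the 128 MB value is just a placeholder, not something I’ve verified for this model):

    import os  # already imported at the top of the script above
    # max_split_size_mb prevents the allocator from splitting blocks larger than
    # this size, which can help with fragmentation; it can also be exported as an
    # environment variable in the shell instead.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"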

PyTorch version: 1.10.0.dev20210630+cu113
CUDA version: 11.3
GPU model: NVIDIA Quadro RTX 6000

Thanks in advance!