GPU memory usage increases during mixed precision training

Hi all,

I am training a GRU model with 3 layers and 3840 hidden units that generates 3 outputs (namely output_max, output_min, output_max_first). Due to the size of the model, I also implemented mixed precision (with memory snapshots to understand the memory usage) as follows:

for b in range(0, len(train_x), batch_size):

    # Build the current batch and move it to the GPU
    # (implied by the CUDA memory checks below)
    input = torch.tensor(train_x[b:(b + batch_size)], dtype=torch.float, device="cuda")
    target = torch.tensor(train_y[b:(b + batch_size)], dtype=torch.float, device="cuda")

    optimizer.zero_grad()

    print("Memory Check 1", torch.cuda.memory_reserved())

    # Forward pass and loss computation under autocast
    with amp.autocast():
        output_max, output_min, output_max_first = my_gru(input)

        max_loss = loss_weight[0] * max_loss_func(output_max, target[:, 0])
        min_loss = loss_weight[1] * min_loss_func(output_min, target[:, 1])
        max_first_loss = loss_weight[2] * classification_loss_func(output_max_first, target[:, 2])

    print("Memory Check 2", torch.cuda.memory_reserved())

    # Backward passes through the scaler; retain_graph is needed on the first two
    # because the three losses share the same graph
    scaler.scale(max_loss).backward(retain_graph=True)
    scaler.scale(min_loss).backward(retain_graph=True)
    scaler.scale(max_first_loss).backward()

    scaler.unscale_(optimizer)

    print("Memory Check 3", torch.cuda.memory_reserved())

    scaler.step(optimizer)
    scaler.update()

The printed results are as follows:

Memory Check 1 1803550720
Memory Check 2 2732589056
Memory Check 3 4586471424

However, without mixed precision the above code produces the following results:

Memory Check 1 1803550720
Memory Check 2 2732589056
Memory Check 3 3659530240

The memory usage during the forward pass is the same (which is weird), but the memory usage during the backward pass is lower.

Moreover, the memory usage seems to be carried over to the next loop (i.e. for the mixed precision implementation, Memory Check 1 of the next iteration returns 4586471424).

Did I somehow implement mixed precision incorrectly?

Many thanks for your help!

torch.cuda.memory_reserved() returns the allocated plus cached memory, which might be higher, e.g. if you are using torch.backends.cudnn.benchmark=True (cudnn would then profile more kernels, if they are available for mixed precision).
To check only the allocated memory, you could use torch.cuda.memory_allocated().
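
As a minimal sketch (the toy nn.GRU, its sizes, and the batch shape below are just placeholders, not your actual model), you could log both counters around a forward/backward step to see whether the extra memory is truly allocated by tensors or only cached by the allocator:

import torch

def log_cuda_memory(tag):
    # memory_allocated(): memory currently occupied by live tensors
    # memory_reserved(): allocated memory plus what the caching allocator keeps around
    print(f"{tag}: allocated={torch.cuda.memory_allocated()} "
          f"reserved={torch.cuda.memory_reserved()}")

# Hypothetical toy model and batch, just to make the sketch self-contained
model = torch.nn.GRU(input_size=64, hidden_size=128, num_layers=2).cuda()
x = torch.randn(32, 16, 64, device="cuda")

log_cuda_memory("before forward")
with torch.cuda.amp.autocast():
    out, _ = model(x)
    loss = out.float().pow(2).mean()
log_cuda_memory("after forward")

loss.backward()
log_cuda_memory("after backward")

If allocated stays roughly the same while reserved grows, the difference is cached memory held by the allocator rather than additional live tensors.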