Hi all,
I am training a GRU model with 3 layers and 3840 hidden units that generates 3 outputs (namely output_max, output_min, output_max_first). Due to the size of the model, I also implemented mixed precision (with memory snapshots to understand the memory usage), as follows:
```python
for b in range(0, len(train_x), batch_size):
    input = torch.tensor(train_x[b:(b + batch_size)], dtype=torch.float, requires_grad=False)
    target = torch.tensor(train_y[b:(b + batch_size)], dtype=torch.float, requires_grad=False)
    optimizer.zero_grad()
    print("Memory Check 1", torch.cuda.memory_reserved())

    # forward pass and loss computation under autocast
    with amp.autocast():
        output_max, output_min, output_max_first = my_gru(input)
        max_loss = loss_weight[0] * max_loss_func(output_max, target[:, 0])
        min_loss = loss_weight[1] * min_loss_func(output_min, target[:, 1])
        max_first_loss = loss_weight[2] * classification_loss_func(output_max_first, target[:, 2])
    print("Memory Check 2", torch.cuda.memory_reserved())

    # three separate backward passes, keeping the graph alive for the first two
    scaler.scale(max_loss).backward(retain_graph=True)
    scaler.scale(min_loss).backward(retain_graph=True)
    scaler.scale(max_first_loss).backward()
    scaler.unscale_(optimizer)
    print("Memory Check 3", torch.cuda.memory_reserved())

    scaler.step(optimizer)
    scaler.update()
```
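For what it's worth, a toy check (on CPU, without amp, using made-up losses rather than my actual model) suggests that three separate backward calls with retain_graph=True accumulate the same gradients as a single backward on the summed loss, so combining them and dropping retain_graph=True might be an option:

```python
import torch

# Toy check: three backward() calls with retain_graph accumulate the
# same gradients as one backward() on the summed loss.
w = torch.ones(3, requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0])

# Separate backwards (mirrors the structure of my training loop)
y = w * x
l1, l2, l3 = y[0] * 2, y[1] * 3, y[2] * 4
l1.backward(retain_graph=True)
l2.backward(retain_graph=True)
l3.backward()
g_separate = w.grad.clone()

# One backward on the sum -- no retain_graph needed
w.grad = None
y = w * x
total = y[0] * 2 + y[1] * 3 + y[2] * 4
total.backward()
g_summed = w.grad.clone()

print(torch.allclose(g_separate, g_summed))  # True
```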
The printed results are as follows:
```
Memory Check 1 1803550720
Memory Check 2 2732589056
Memory Check 3 4586471424
```
However, without mixed precision, the same code generates the following results:
```
Memory Check 1 1803550720
Memory Check 2 2732589056
Memory Check 3 3659530240
```
So the forward pass uses the same amount of memory in both cases (which seems weird), while the backward pass actually uses less memory without mixed precision.
Moreover, the memory usage seems to be carried forward into the next loop iteration (i.e. for the mixed-precision implementation, Memory Check 1 of the next iteration returns 4586471424).
Did I somehow implement mixed precision incorrectly?
Many thanks for your help!