In PyTorch, I have used save_for_backward to save the input tensors of certain layers before running the forward of those layers under torch.no_grad. My goal is to recompute the intermediate activations during the backward of a torch.autograd.Function without any memory leak, and to predict the GPU memory usage during backward in a layer-by-layer manner. However, I observed that the GPU memory usage reported by torch.cuda.memory_allocated is much higher than I expected. What could be the cause? Or could someone point me to the “formal” way of measuring GPU memory usage during the backward of a torch.autograd.Function, layer by layer? I also tried APIs like register_full_backward_hook, but that made things even worse: the reported usage was much higher than with my current method. (According to a recent issue in the PyTorch GitHub tracker, PyTorch seems to leak memory when self is used inside these hooks.)
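
For reference, this is roughly the kind of per-layer hook measurement I mean (a minimal sketch, not my exact code; the toy model, the hook name log_backward_memory, and the log format are all placeholders):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1),
).cuda()

def log_backward_memory(name):
    def hook(module, grad_input, grad_output):
        # Report allocated GPU memory once this module's backward has run.
        torch.cuda.synchronize()
        print('{}: {:.1f}MB allocated'.format(
            name, torch.cuda.memory_allocated() / (1024 * 1024)))
    return hook

for name, module in model.named_modules():
    if name:  # skip the root container, hook only the submodules
        module.register_full_backward_hook(log_backward_memory(name))

x = torch.randn(8, 3, 32, 32, device='cuda')
model(x).sum().backward()

And here is my current implementation: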
import gc

import torch


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # args holds the segment's input tensors (first `length` entries)
        # followed by its parameters; save all of them for recomputation.
        ctx.run_function = run_function
        ctx.length = length
        ctx.save_for_backward(*args)
        # Run the forward pass without recording an autograd graph.
        with torch.no_grad():
            output_tensors = ctx.run_function(*args[:length])
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        input_tensors = ctx.saved_tensors[:ctx.length]
        input_params = list(ctx.saved_tensors[ctx.length:])
        # Detach so the recomputed graph is rooted at these leaves.
        input_tensors = [x.detach().requires_grad_(True) for x in input_tensors]
        # Recompute the segment's forward pass, this time recording the graph.
        with torch.enable_grad():
            output_tensors = ctx.run_function(*input_tensors)
        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=False,
        )
        # Drop the gradients w.r.t. the segment inputs, keeping only the
        # parameter gradients.
        input_grads = (None,) * ctx.length + input_grads[ctx.length:]
        del output_tensors
        if ctx.run_function.is_enabled_logging():
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(
                '## VGG, segment re-forward, torch.cuda.memory_allocated(): {}MB'
                .format(torch.cuda.memory_allocated() / (1024 * 1024))
            )
        # Gradients for the non-tensor arguments (run_function, length) are None.
        return (None, None) + input_grads
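
For completeness, this is how I invoke the function (a minimal sketch based on its signature; LoggedSegment and the is_enabled_logging wiring are illustrative assumptions standing in for my real VGG segments):

import torch
import torch.nn as nn

class LoggedSegment(nn.Sequential):
    # Hypothetical wrapper providing the custom is_enabled_logging()
    # method that CheckpointFunction.backward expects.
    def is_enabled_logging(self):
        return True

segment = LoggedSegment(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU()).cuda()
x = torch.randn(8, 3, 224, 224, device='cuda')
params = tuple(segment.parameters())

# One input tensor, so length == 1; the parameters follow the inputs.
out = CheckpointFunction.apply(segment, 1, x, *params)
out.sum().backward()  # triggers the recomputation and the memory log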