saved_tensors_hooks doesn't seem to move tensors from GPU to CPU

I want to implement checkpointing that offloads saved activations to the CPU instead of recomputing them, in order to save GPU memory.
The code looks like this (using PyTorch stable version 2.0.1):

import torch


def pack_hook(x: torch.Tensor):
    # Called when autograd saves a tensor for backward: copy it to the CPU
    # and remember which device it came from.
    device = x.device
    x = x.detach().cpu()
    return (device, x)


def unpack_hook(t):
    # Called during backward: move the saved tensor back to its original device.
    device, data = t
    return data.to(device)


for epoch in range(100):
    # model and input are defined elsewhere
    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        output = model(input)
    output.backward()
    torch.cuda.empty_cache()

After the first epoch finished, GPU memory usage did not decrease, so the second epoch raised a CUDA out-of-memory error.
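
For reference, PyTorch also ships a built-in hook pair that performs exactly this kind of CPU offload, torch.autograd.graph.save_on_cpu. A minimal sketch of using it as a baseline (model and input are placeholders, as above):

import torch

# save_on_cpu saves activations to (optionally pinned) CPU memory during
# forward and copies them back to the original device during backward.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    output = model(input)
output.backward()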

How did you check the memory usage? Did you compare the allocated stats, the reserved stats, or both?
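
For example (a quick sketch): allocated is the memory actually occupied by live tensors, while reserved is what the caching allocator holds on to and only releases to the driver on empty_cache().

import torch

dev = torch.cuda.current_device()  # or an explicit device index
print(f"allocated     {torch.cuda.memory_allocated(dev) / 1024**3:.2f} GiB")
print(f"reserved      {torch.cuda.memory_reserved(dev) / 1024**3:.2f} GiB")
print(f"max allocated {torch.cuda.max_memory_allocated(dev) / 1024**3:.2f} GiB")
print(f"max reserved  {torch.cuda.max_memory_reserved(dev) / 1024**3:.2f} GiB")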

import torch


def pack_hook(x: torch.Tensor):
    device = x.device
    x = x.detach().cpu()
    return (device, x)


def unpack_hook(t):
    device, data = t
    return data.to(device)


torch.cuda.reset_peak_memory_stats(1)  # reset peaks on cuda:1, where the tensors live
# each tensor is 64 * 4096 * 4096 fp16 elements, about 2 GiB
a = torch.ones((64, 4096, 4096), dtype=torch.float16,
               device="cuda:1", requires_grad=True) * 2
b = torch.ones((64, 4096, 4096), dtype=torch.float16,
               device="cuda:1", requires_grad=True) * 3


def fn():
    # the multiplication saves both inputs for backward,
    # so both a and b go through pack_hook
    y = a * b
    return y


print(f"max allocated {torch.cuda.max_memory_allocated(1)/1024/1024/1024}G")
print(f"max reserved {torch.cuda.max_memory_reserved(1)/1024/1024/1024}G")

torch.cuda.reset_peak_memory_stats(1)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = fn()
    print(f"max allocated {torch.cuda.max_memory_allocated(1)/1024/1024/1024}G")
    print(f"max reserved {torch.cuda.max_memory_reserved(1)/1024/1024/1024}G")
    torch.cuda.reset_peak_memory_stats(1)
    output.mean().backward()
    print(f"max allocated {torch.cuda.max_memory_allocated(1)/1024/1024/1024}G")
    print(f"max reserved {torch.cuda.max_memory_reserved(1)/1024/1024/1024}G")

print("ALL DONE.")

I used this code to test it, and below is the output:

max allocated 8.0G
max reserved 8.0G

max allocated 10.0G
max reserved 10.0G

max allocated 20.000000953674316G
max reserved 20.001953125G

ALL DONE.
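
For scale (my own arithmetic): each of these tensors holds 64 * 4096 * 4096 fp16 elements, i.e. 2 GiB, so the three peaks above correspond to 4, 5, and 10 such tensors alive at once.

# 64 * 4096 * 4096 elements * 2 bytes (fp16) = 2 GiB per tensor
numel = 64 * 4096 * 4096    # 1,073,741,824 elements
print(numel * 2 / 1024**3)  # 2.0 (GiB)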

And if I remove the with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): context, the output is:

max allocated 8.0G
max reserved 8.0G
max allocated 10.0G
max reserved 10.0G
max allocated 16.000000953674316G
max reserved 16.001953125G
ALL DONE.

It seems the pack and unpack hook functions actually increase the memory usage?
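
My current guess (not verified against the autograd internals): in this test, a and b stay referenced by Python variables, so the CPU copies made by pack_hook cannot free the GPU originals; during backward, unpack_hook then materializes a second GPU copy of each, which matches the extra 2 + 2 GiB (20 G vs. 16 G). The offload should only pay off once the saved tensors have no other live references, e.g. (a sketch reusing the definitions above):

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = fn()

# The saved copies of a and b now live on the CPU, so dropping the Python
# references lets the allocator release the two 2 GiB GPU originals.
del a, b
torch.cuda.empty_cache()

output.mean().backward()  # unpack_hook copies a and b back for backward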