Your use case sounds similar to CPU offloading, which, if I'm not mistaken, uses torch.autograd.graph.saved_tensors_hooks or torch.autograd.graph.save_on_cpu. Both are context managers that control how autograd stores the tensors it saves for the backward pass, so you could take a look at them.
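In case it helps, here is a minimal sketch of both context managers. The tensor shapes, the `a * b` toy computation, and the helper names `pack_to_cpu` / `unpack_from_cpu` are just illustrative assumptions, not something from your code, and I'm not claiming this is how any particular offloading library does it internally.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks, save_on_cpu

# Hypothetical helpers for illustration: move saved tensors to CPU when
# autograd stores them, and back to their original device for backward.
def pack_to_cpu(tensor):
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    device, tensor = packed
    return tensor.to(device)

# Falls back to CPU so the sketch also runs without a GPU
# (in that case the hooks are effectively no-ops).
device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(4, 4, device=device, requires_grad=True)
b = torch.randn(4, 4, device=device, requires_grad=True)

# Option 1: custom pack/unpack hooks.
with saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss_hooks = (a * b).sum()
loss_hooks.backward()
grad_hooks = a.grad.clone()

# Option 2: the built-in convenience context manager.
a.grad = None
b.grad = None
with save_on_cpu(pin_memory=False):
    loss_builtin = (a * b).sum()
loss_builtin.backward()
grad_builtin = a.grad.clone()
```

Only the forward pass needs to run inside the context manager; the hooks are attached to the saved tensors at that point, and the unpack hook fires when backward needs them. With `save_on_cpu`, setting `pin_memory=True` may speed up the transfers at the cost of pinned host memory.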