Modifying forward/backward pass

Your use case sounds similar to CPU offloading, which, if I'm not mistaken, is done with torch.autograd.graph.saved_tensors_hooks (torch.autograd.graph.save_on_cpu is a ready-made pair of hooks built on top of it). You could take a look at these context managers.
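Here is a minimal sketch of how the two context managers hook into the forward/backward pass. The pack hook runs in the forward pass whenever autograd saves a tensor for backward, and the unpack hook runs in the backward pass when that tensor is needed again. The hook names `pack_to_cpu`/`unpack_from_cpu`, the `pack_calls` counter and the `forward_backward` helper are made up for illustration. Only `saved_tensors_hooks` and `save_on_cpu` are the actual PyTorch APIs.

```python
import contextlib

import torch
from torch.autograd.graph import save_on_cpu, saved_tensors_hooks

# Hypothetical counter, only used to show that the pack hook actually fires.
pack_calls = []

def pack_to_cpu(tensor):
    # Forward pass: called for every tensor autograd saves for backward.
    # Remember the original device and keep only a CPU copy.
    pack_calls.append(tuple(tensor.shape))
    return tensor.device, tensor.cpu()

def unpack_from_cpu(packed):
    # Backward pass: called when autograd needs the saved tensor again.
    device, cpu_tensor = packed
    return cpu_tensor.to(device)

def forward_backward(w, x, ctx):
    # Tensors saved inside the `with` block go through the context's hooks.
    with ctx:
        loss = (x @ w).sin().sum()
    # The unpack hooks run here, even though we are outside the block.
    loss.backward()
    grad = w.grad.clone()
    w.grad = None
    return grad

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 4, device=device)
w = torch.randn(4, 3, device=device, requires_grad=True)

reference = forward_backward(w, x, contextlib.nullcontext())  # no hooks
custom = forward_backward(w, x, saved_tensors_hooks(pack_to_cpu, unpack_from_cpu))
builtin = forward_backward(w, x, save_on_cpu(pin_memory=False))

print(torch.allclose(reference, custom), torch.allclose(reference, builtin))
print("pack hook calls:", len(pack_calls))
```

All three runs should produce the same gradients. The only difference is where the saved activations live between the forward and backward pass. If you need something other than offloading, such as compressing, recomputing or logging saved tensors, you can put that logic in your own pack/unpack pair.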
