Gradient checkpointing, transfer between GPU and CPU

I’m trying to move memory between the GPU and CPU with gradient checkpointing, using this code.

In my forward function I tried:

tensor = args[0]
tensor = tensor.cpu()   # copy the activation to host memory
tensor = tensor.cuda()  # copy it back to the GPU

When I move the tensor to the CPU, its grad_fn is set to None.
Does anyone know how to avoid this?


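During the first forward pass through a checkpointed segment, gradients are not tracked, so every operation (including .cpu() and .cuda()) produces a tensor whose grad_fn is None. The segment is run again with gradients enabled during the backward pass, and that recomputation is what builds the graph. You can check whether gradients are being tracked with torch.is_grad_enabled().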
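A minimal sketch to see this, assuming torch.utils.checkpoint.checkpoint with use_reentrant=True (the variant whose first pass runs without gradient tracking). The function segment and the list grad_log are hypothetical names; the sketch uses .to(x.device) instead of .cuda() so it also runs on a CPU-only machine. Each time the segment runs, it records torch.is_grad_enabled() and whether the round-tripped tensor has a grad_fn.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical log: one (grad_enabled, has_grad_fn) entry per run of the segment.
grad_log = []

def segment(x):
    y = x * 2
    # Round trip through host memory, as in the question.
    y = y.cpu().to(x.device)
    grad_log.append((torch.is_grad_enabled(), y.grad_fn is not None))
    return y.sin()

x = torch.randn(4, requires_grad=True)
out = checkpoint(segment, x, use_reentrant=True)  # first run: grad disabled
out.sum().backward()                              # recomputation: grad enabled

print(grad_log)  # [(False, False), (True, True)]
```

The first entry comes from the forward call, where grad_fn is None; the second comes from the recomputation in backward, where the graph (including the device copy) is built and gradients reach x normally.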