Error: checkpointing is not compatible with .grad(), please use .backward() if possible

Hi,
I implemented an RPC distributed model following the tutorial example.

I run my model on a single GPU and use checkpointing to reduce GPU memory usage. After I compute the loss, I run the backward pass with torch.distributed.autograd.backward(context_id, [loss]), and then the error above occurs.
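For reference, here is a minimal sketch of the pattern I'm using (the module, shapes, and single-process RPC setup are just illustrative, not my actual model):

```python
import os

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.utils.checkpoint import checkpoint

# Single-process RPC setup just so the sketch is runnable; my real setup
# follows the RPC tutorial with multiple workers.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)

class CheckpointedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
        self.block2 = nn.Linear(128, 10)

    def forward(self, x):
        # Recompute block1's activations during the backward pass to save GPU memory.
        h = checkpoint(self.block1, x)
        return self.block2(h)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CheckpointedNet().to(device)
loss_fn = nn.CrossEntropyLoss()

with dist_autograd.context() as context_id:
    x = torch.randn(32, 128, device=device, requires_grad=True)
    target = torch.randint(0, 10, (32,), device=device)
    loss = loss_fn(model(x), target)
    # This is the call that raises
    # "checkpointing is not compatible with .grad(), please use .backward() if possible"
    dist_autograd.backward(context_id, [loss])

rpc.shutdown()
```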

I don't use torch.distributed.autograd.grad() anywhere, so I wonder: does torch.distributed.autograd.backward() call grad() internally? If not, I'll go back and check my code again.

Thanks.

Hi, yes, distributed autograd uses grad() internally.
There are plans to make checkpointing compatible with distributed autograd, tracked here: New utils.checkpoint rollout. · Issue #65537 · pytorch/pytorch · GitHub
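In the meantime, one direction you could try (a sketch only; I haven't verified it against distributed autograd, and it assumes a PyTorch version that already ships the non-reentrant variant from that rollout) is the use_reentrant=False form of torch.utils.checkpoint.checkpoint, which avoids the custom autograd Function whose backward performs the .grad() check:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
x = torch.randn(32, 128, requires_grad=True)

# Non-reentrant checkpointing: implemented with saved-tensor hooks rather
# than the autograd Function whose backward raises
# "checkpointing is not compatible with .grad()". Whether it fully supports
# distributed autograd depends on your PyTorch version.
h = checkpoint(block, x, use_reentrant=False)
```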