DDP training incompatibility with checkpoint and detach

I am using PyTorch DDP to train my model. It turns out that if I use DDP, I cannot use gradient checkpointing or detach gradients. This incompatibility is a big problem, because both techniques are important for my use case.

My model roughly consists of two parts: a language model that generates representations, whose weights are detached, and another part that is trained with gradients.

Here is the code of the language model:

if exists(config.msa_bert.msa_bert_config.model_weight) and not config.msa_bert.skip_load_msa_bert:
    self.bert_model = load_pretrain(self.bert_model, config.msa_bert.msa_bert_config.model_weight)

if config.msa_bert.msa_bert_config.freeze:
    print('freeze pretrained msa transformer')
    for param in self.bert_model.parameters():
        param.detach_()
    self.bert_model.eval()
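For comparison, I am aware that freezing is often done by disabling requires_grad instead of detaching the parameters in place; as far as I understand, DDP then does not expect gradients for those parameters at all. This is only a rough sketch of that alternative, not my actual code:

if config.msa_bert.msa_bert_config.freeze:
    # Alternative freezing approach: disable gradients instead of detach_().
    for param in self.bert_model.parameters():
        param.requires_grad_(False)
    self.bert_model.eval()

I have not yet verified whether this would remove the need for find_unused_parameters=True in my setup.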
Note that in the other part of my model there are recycle iterations that use detach:

for i in range(n_recycle):
    msa_fea, pair_fea = self.feat_extractor(msa_fea, pair_fea)
    msa_fea, pair_fea = msa_fea.detach_(), pair_fea.detach_()
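To make the recycling pattern concrete, here is a self-contained toy version of the loop (ToyExtractor is made up for illustration). In this sketch I only detach between recycles, not after the last one, so that a loss can still be backpropagated; I assume that matches the intent of the real loop:

import torch
import torch.nn as nn

# Toy stand-in for self.feat_extractor (hypothetical, for illustration only).
class ToyExtractor(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, msa_fea, pair_fea):
        return self.linear(msa_fea), self.linear(pair_fea)

extractor = ToyExtractor()
msa_fea = torch.randn(2, 8)
pair_fea = torch.randn(2, 8)

n_recycle = 3
for i in range(n_recycle):
    msa_fea, pair_fea = extractor(msa_fea, pair_fea)
    if i < n_recycle - 1:
        # Cut the graph between recycles so earlier passes are not backpropagated.
        msa_fea, pair_fea = msa_fea.detach(), pair_fea.detach()

loss = msa_fea.sum() + pair_fea.sum()
loss.backward()  # gradients flow only through the final recycle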
When using DDP, I have to turn on find_unused_parameters=True, otherwise an error is raised: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.
It seems that if your model has detached parameters, you have to turn this on.
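For completeness, the DDP wrapping itself is nothing special; a minimal sketch with a toy model (launched with torchrun) looks roughly like this:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Launch with: torchrun --nproc_per_node=1 this_script.py
dist.init_process_group(backend="gloo")  # "nccl" on GPU nodes
local_rank = int(os.environ.get("LOCAL_RANK", 0))

model = nn.Linear(8, 8)  # toy stand-in for my real model
ddp_model = DDP(
    model,
    find_unused_parameters=True,  # needed once some parameters never receive gradients
)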
Here comes the problem: if I keep find_unused_parameters=True and enable checkpointing, an error is raised because a variable is marked as ready twice.
I conjecture that during the forward pass those detached parameters are marked as ready because of find_unused_parameters=True, and then they are somehow marked as ready again, which causes this error.
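For reference, by "checkpoint" I mean the standard torch.utils.checkpoint.checkpoint call, roughly as in the sketch below (the checkpointed block is a toy stand-in). I have read that the non-reentrant variant (use_reentrant=False) interacts differently with DDP, but I have not confirmed whether it avoids this error:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for the checkpointed part of the model.
block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)

# Reentrant checkpointing (the default in my version when use_reentrant is not passed);
# this is what I enable when the "marked ready twice" error appears.
y = checkpoint(block, x)

# Non-reentrant variant; I have not verified whether it avoids the error.
y2 = checkpoint(block, x, use_reentrant=False)

(y.sum() + y2.sum()).backward()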

I am wondering: in what cases would a parameter be marked as ready again?
Also, what does it mean for a parameter to be marked as ready? I think it has something to do with autograd and the gradient computation graph.

I accidentally found that if I turn off recycling (i.e., the detach) and checkpointing while keeping find_unused_parameters=True, DDP training works.
However, I cannot turn them off, as they are important for efficiency. Without checkpointing, GPU memory explodes.

Responded on the issue: DPP training incompatibility with checkpoint and detach · Issue #83074 · pytorch/pytorch · GitHub