RuntimeError: Expected to mark a variable ready only once

Hi there!

I am trying to run code that works perfectly well without distributed training, but fails with the following error message when distributed training is enabled:

    RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:
    1) Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes
    2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.
    3) Incorrect unused parameter detection. The return value of the forward function is inspected by the distributed data parallel wrapper to figure out if any of the module's parameters went unused. For unused parameters, DDP would not expect gradients from then. However, if an unused parameter becomes part of the autograd graph at a later point in time (e.g., in a reentrant backward when using checkpoint), the gradient will show up unexpectedly. If all parameters in the model participate in the backward pass, you can disable unused parameter detection by passing the keyword argument find_unused_parameters=False to torch.nn.parallel.DistributedDataParallel.

I think I've identified the root cause of this issue: PyTorch checkpointing in the image backbone. After disabling checkpointing (with_cp=False) in the model, the code runs smoothly. But when checkpointing is enabled (the default, to save memory), the program throws the error above.

The complete code of the backbone model is available here. Below is a simplified excerpt.

    from torch.utils import checkpoint

    def forward(self, x, hw_shape):
        for block in self.blocks:
            if self.with_cp:
                # not the latest PyTorch
                # see https://github.com/pytorch/pytorch/issues/63394
                # x = checkpoint.checkpoint(block, x, hw_shape)
                x = checkpoint.checkpoint(lambda x: block(x, hw_shape), x)
            else:
                x = block(x, hw_shape)
        return x

Is there a way to fix this?
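
One thing that may be worth checking in the excerpt above (a hedged observation, not a confirmed diagnosis): `lambda x: block(x, hw_shape)` looks up `block` when the lambda is called, not when it is created. Checkpointing calls the wrapped function again during backward to recompute activations, and by then the loop has finished, so every recomputed segment would run the last block. Under DDP, that would reuse the last block's parameters in several reentrant backward passes, which matches reason 2) of the error. In the non-distributed run, it would likely not raise anything, but the gradients could be computed through the wrong block. The plain-Python sketch below reproduces the pitfall; `fake_checkpoint`, `make_block` and `recompute` are hypothetical stand-ins, not PyTorch APIs.

```python
# Plain-Python sketch of the late-binding pitfall with checkpoint-style recomputation.
# fake_checkpoint, make_block and recompute are hypothetical stand-ins, not PyTorch APIs.

recompute_queue = []


def fake_checkpoint(fn, x):
    # Run fn now, and remember (fn, x) so it can be re-run later, the way
    # checkpointing recomputes the forward pass during backward.
    recompute_queue.append((fn, x))
    return fn(x)


def make_block(i):
    # Each hypothetical "block" appends its own index to the input list.
    return lambda x, hw_shape: x + [i]


blocks = [make_block(i) for i in range(3)]
hw_shape = (4, 4)


def forward(x, bind_early):
    recompute_queue.clear()
    for block in blocks:
        if bind_early:
            # The default argument captures the current block when the lambda is created.
            x = fake_checkpoint(lambda x, block=block: block(x, hw_shape), x)
        else:
            # The closure looks up `block` only when the lambda is called.
            x = fake_checkpoint(lambda x: block(x, hw_shape), x)
    return x


def recompute():
    # Re-run every saved function on its saved input, as in the backward pass.
    return [fn(inp) for fn, inp in recompute_queue]


print(forward([], bind_early=False), recompute())  # [0, 1, 2] [[2], [0, 2], [0, 1, 2]]
print(forward([], bind_early=True), recompute())   # [0, 1, 2] [[0], [0, 1], [0, 1, 2]]
```

The forward output is identical in both cases; only the recomputation differs. In the backbone, binding the block with a default argument, e.g. `x = checkpoint.checkpoint(lambda x, block=block: block(x, hw_shape), x)`, or passing the arguments directly as in the commented-out line, avoids the late lookup. Separately, on PyTorch versions where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant=False`, the non-reentrant variant avoids reentrant backward passes altogether. Neither change has been verified against this model.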