Gradient checkpointing causes a "parameters were not used in producing loss" DDP error?

I have code written as follows:


# x_out is channels-last (N, H, W, C); permute to (N, C, H, W) before calling upscale_layer
if self.use_checkpoint:
    out = torch.utils.checkpoint.checkpoint(upscale_layer, x_out.permute(0, 3, 1, 2).contiguous())
else:
    out = upscale_layer(x_out.permute(0, 3, 1, 2).contiguous())

When all the parameters before this point are frozen, running with gradient checkpointing enabled leads to the following error:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by ...

If use_checkpoint is False, no error occurs.
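In case it helps, here is a minimal self-contained sketch of the kind of setup I mean (the layer shapes and names are placeholders, not my actual model; run under DDP with torchrun, e.g. torchrun --nproc_per_node=2 repro.py):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        # Stand-in for the frozen layers that come before the snippet above.
        self.backbone = nn.Linear(16, 16)
        for p in self.backbone.parameters():
            p.requires_grad = False
        # Stand-in for upscale_layer; its parameters are still trainable.
        self.upscale_layer = nn.Conv2d(16, 16, kernel_size=1)

    def forward(self, x):
        x_out = self.backbone(x)  # channels-last activations, (N, H, W, C)
        if self.use_checkpoint:
            out = checkpoint(self.upscale_layer, x_out.permute(0, 3, 1, 2).contiguous())
        else:
            out = self.upscale_layer(x_out.permute(0, 3, 1, 2).contiguous())
        return out

def main():
    dist.init_process_group("gloo")  # CPU backend, just for the repro
    model = DDP(Net(use_checkpoint=True))  # find_unused_parameters left at its default (False)
    opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
    for _ in range(2):  # the error shows up once a new iteration starts
        x = torch.randn(2, 8, 8, 16)
        loss = model(x).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()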

Why? Is this expected?