I have code written as follows:
```python
if self.use_checkpoint:
    out = torch.utils.checkpoint.checkpoint(upscale_layer, x_out.permute(0, 3, 1, 2).contiguous())
else:
    out = upscale_layer(x_out.permute(0, 3, 1, 2).contiguous())
```
When all the parameters before this line are frozen, enabling gradient checkpointing leads to the following error:
```
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by ...
```
If `use_checkpoint` is False, no error occurs.
Why does this happen? Is this expected behavior?
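For context, here is a minimal, self-contained sketch of the structure as I understand it (the module names and sizes are simplified stand-ins, not my real model): a frozen part producing `x_out`, the optionally checkpointed layer, and a trainable head after it, all wrapped in DDP and run for two iterations.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class ToyModel(nn.Module):
    """Frozen backbone -> (optionally checkpointed) upscale layer -> trainable head."""

    def __init__(self, use_checkpoint):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.backbone = nn.Linear(8, 8)       # stand-in for the frozen layers producing x_out
        self.upscale_layer = nn.Linear(8, 8)  # stand-in for upscale_layer
        self.head = nn.Linear(8, 1)           # stand-in for the trainable layers after it

    def forward(self, x):
        x_out = self.backbone(x)
        if self.use_checkpoint:
            out = checkpoint(self.upscale_layer, x_out)  # same call pattern as above
        else:
            out = self.upscale_layer(x_out)
        return self.head(out)


def worker(rank, world_size, use_checkpoint):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = ToyModel(use_checkpoint)
    # Freeze everything before the checkpointed call, as described above.
    for p in model.backbone.parameters():
        p.requires_grad = False
    model = DDP(model)  # find_unused_parameters is left at its default (False)

    for _ in range(2):  # the reducer error would surface at the start of the second iteration
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    # Flip the last element of args to compare the checkpointed and non-checkpointed branches.
    mp.spawn(worker, args=(2, True), nprocs=2)
```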