Exclude loss from some processes in backward()

I have set up my model with DistributedDataParallel, which is working well (i.e. it runs). The problem is that during early epochs some processes compute infinite/NaN losses. Can I simply omit the call to backward() in those cases, or will this confuse the synchronization that happens under the hood?

Successive calls to backward() are independent of one another, but within a single call the processes are not independent: DDP all-reduces gradients during backward(), so every rank has to take part in the same call.
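To make this concrete, here is a minimal sketch of the problematic pattern (hypothetical names; it assumes the usual DDP setup with ddp_model and inputs already defined): DDP launches the gradient all-reduce inside backward(), so a rank that skips the call leaves the other ranks blocked in that collective.

loss = ddp_model(inputs).sum()  # hypothetical forward pass on each rank
if torch.isfinite(loss):        # ranks with an inf/NaN loss skip the branch...
    loss.backward()             # ...and the ranks that do call backward() hang in the gradient all-reduce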

To skip backward() on arbitrary iterations, you can use the no_sync context manager:
https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=no_sync#torch.nn.parallel.DistributedDataParallel.no_sync
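A minimal sketch of that approach (hypothetical names, assuming the usual DDP setup): every rank enters no_sync() for the iteration, so backward() launches no collectives and an individual rank can safely skip it. Locally accumulated gradients are then synchronized by the first backward() outside the context.

with ddp_model.no_sync():
    loss = ddp_model(inputs).sum()  # hypothetical forward pass
    if torch.isfinite(loss):        # only ranks with a finite loss accumulate gradients
        loss.backward()             # no all-reduce here, so skipping it on other ranks cannot cause a hang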
However, be aware that it must be used on all ranks, otherwise there will be synchronization issues. If you only need to skip backward() a few times early in training, you could also consider having all processes agree on whether to skip backward() or not:

import torch.distributed as dist

should_skip_backwards = ...  # each process computes this (e.g. whether its loss came out inf/NaN)
nranks = dist.get_world_size()  # number of participating ranks
should_skip_bwd_list = [None for _ in range(nranks)]
dist.all_gather_object(should_skip_bwd_list, should_skip_backwards)

globally_skip_bwd = any(should_skip_bwd_list)  # every rank sees the same list, so all ranks agree
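The flag can then drive the step identically on every rank, for example (a sketch with hypothetical loss/optimizer from your existing training step):

if not globally_skip_bwd:
    loss.backward()   # every rank calls backward(), so the gradient all-reduce stays matched
    optimizer.step()
optimizer.zero_grad(set_to_none=True)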