Using torch.nn.SyncBatchNorm only to update running mean/variance, without a DDP backward pass

I need to update only the batch norm running mean and variance during training, synchronized across GPUs, so I want to leverage SyncBatchNorm for that. I don’t need gradients for any of the model’s parameters (they are updated from the online model). Is there a way to use SyncBatchNorm for that purpose only?

I have tried using DDP without a backward pass, but that creates sync issues. One workaround I’m considering is to create a fake zero loss and call backward on it. I wonder if there are better ideas.

After some attempts, the dist._all_gather_base call inside SyncBatchNorm kept timing out after training had run for a while. This happens in the exponential moving average (EMA) model, which is wrapped in DDP.

Does anyone have suggestions on what to look into?

Try setting the TORCH_DISTRIBUTED_DEBUG env var to DETAIL and look at the log files for clues.

More detail here: Distributed communication package - torch.distributed — PyTorch master documentation

Timeouts usually happen when collectives are issued out of order or when at least one rank is not participating.
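Both failure modes can be reproduced without GPUs. The toy model below is an analogy, not how torch.distributed is implemented: it uses Python’s threading.Barrier as a stand-in for a collective such as all-gather, and run_ranks is a hypothetical helper. A barrier only opens once every rank has called it, so a rank that skips a collective, or two ranks that issue collectives in different orders, leave their peers waiting until the timeout fires.

```python
import threading
from typing import Dict, List


def run_ranks(sequences: Dict[int, List[str]], timeout: float = 0.5) -> Dict[int, str]:
    """Toy model of collectives (hypothetical helper, not torch.distributed).

    `sequences` maps each rank to the ordered names of the collectives it
    issues. Each name gets one Barrier that opens only when *every* rank
    has reached it, like an all-gather that needs all participants.
    """
    world_size = len(sequences)
    names = {name for seq in sequences.values() for name in seq}
    barriers = {name: threading.Barrier(world_size, timeout=timeout) for name in names}
    results: Dict[int, str] = {}
    lock = threading.Lock()

    def worker(rank: int) -> None:
        try:
            for name in sequences[rank]:
                # Blocks until all ranks arrive, or raises once the timeout breaks the barrier.
                barriers[name].wait()
            outcome = "done"
        except threading.BrokenBarrierError:
            outcome = "timeout"
        with lock:
            results[rank] = outcome

    threads = [threading.Thread(target=worker, args=(r,)) for r in sequences]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_ranks({0: ["all_gather"], 1: ["all_gather"]}))  # both ranks participate
print(run_ranks({0: ["all_gather"], 1: []}))              # rank 1 skips it, rank 0 times out
print(run_ranks({0: ["a", "b"], 1: ["b", "a"]}))          # different order, both time out
```

From the waiting ranks’ point of view, a peer that is blocked somewhere else (for example, in data loading) is indistinguishable from one that skipped the collective: they simply wait until the timeout.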

Have you tried calling SyncBatchNorm under the torch.no_grad() context manager?

You can pick the code from BatchNorm and extract the sync_batch_norm path for your own scenario: pytorch/ at master · pytorch/pytorch · GitHub
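For the use case in this thread, the part of that path that matters is the statistics update: each rank computes a per-channel count, mean and variance, the results are gathered, and the combined values are folded into the running buffers with momentum. The sketch below shows that arithmetic for a single channel in plain Python. It assumes BatchNorm’s default momentum=0.1 and its convention of storing the unbiased variance in running_var; RankStats, local_stats, merge_stats and update_running are hypothetical names, and the real implementation works on tensors and performs the gather with a collective.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class RankStats:
    """Per-channel statistics one rank contributes (hypothetical type)."""
    count: int   # elements seen for this channel on this rank (N * H * W)
    mean: float  # local mean
    var: float   # local biased variance (divided by count)


def local_stats(values: Sequence[float]) -> RankStats:
    """What each rank computes from its own mini-batch for one channel."""
    n = len(values)
    mean = sum(values) / n
    var = sum((x - mean) ** 2 for x in values) / n
    return RankStats(n, mean, var)


def merge_stats(stats: List[RankStats]) -> Tuple[int, float, float]:
    """Combine gathered per-rank stats into global count, mean and biased variance.

    Parallel-variance identity: each rank's variance plus the squared offset
    of its mean from the global mean, weighted by that rank's count.
    """
    total = sum(s.count for s in stats)
    mean = sum(s.count * s.mean for s in stats) / total
    var = sum(s.count * (s.var + (s.mean - mean) ** 2) for s in stats) / total
    return total, mean, var


def update_running(running_mean: float, running_var: float,
                   stats: List[RankStats], momentum: float = 0.1) -> Tuple[float, float]:
    """Fold the global batch statistics into the running buffers (no gradients involved)."""
    total, mean, var = merge_stats(stats)
    unbiased_var = var * total / (total - 1)  # running_var tracks the unbiased estimate
    new_mean = (1 - momentum) * running_mean + momentum * mean
    new_var = (1 - momentum) * running_var + momentum * unbiased_var
    return new_mean, new_var


# Two "ranks" with uneven batch sizes; in a real job this list would come
# from an all-gather across processes.
gathered = [local_stats([1.0, 2.0, 3.0]), local_stats([4.0, 5.0])]
print(merge_stats(gathered))
print(update_running(0.0, 1.0, gathered))
```

Because only the running buffers change, nothing in this update needs a backward pass, which is why extracting just this part can avoid the DDP/backward coupling described in the question.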

Thank you @kumpera for your reply. I’m already calling the model with SyncBatchNorm under the torch.no_grad() context manager. The timeout happens after training has run fine for a while (not at any particular epoch or step). I will try setting the TORCH_DISTRIBUTED_DEBUG environment variable and look into the details.

Do you mean I should adapt that part of the code to my needs?

Yes, that code imposes some restrictions that might not work for you.

In my case the sync timed out, but SyncBatchNorm itself works as expected in the forward pass. The real issue was that I have multiple data loaders, each with multiple workers, and together they turned out to create a deadlock. Once I reduced num_workers to 1, the issue was resolved.
