I need to update only the batch norm running mean and variance during training across GPUs, so I want to leverage SyncBatchNorm for that. I don’t need gradients for any of the model’s parameters (they are updated from the online model). Is there a way to use SyncBatchNorm for that purpose only?
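Here is a minimal sketch of what I’m after (the tiny Sequential model is just a stand-in for my EMA network; the point is only the `convert_sync_batchnorm` call and the frozen parameters):

```python
import torch
import torch.nn as nn

# Stand-in for the EMA model; in practice this is the
# exponential-moving-average copy of the online network.
ema_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Replace every BatchNorm layer with SyncBatchNorm. The conversion
# itself needs no process group, but the forward pass does.
ema_model = nn.SyncBatchNorm.convert_sync_batchnorm(ema_model)

# No gradients are needed: the weights are copied from the online model.
for p in ema_model.parameters():
    p.requires_grad_(False)

# Running mean/var are updated (and synced across ranks) during the
# forward pass, not the backward pass, so torch.no_grad() is fine.
# SyncBatchNorm requires CUDA inputs and an initialized process group
# in train mode, hence the guard.
if torch.cuda.is_available() and torch.distributed.is_initialized():
    ema_model.cuda().train()
    with torch.no_grad():
        _ = ema_model(torch.randn(4, 3, 16, 16, device="cuda"))
```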
I have attempted to use DDP without calling backward, but that creates sync issues. One workaround I’m considering is to create a fake 0 loss and perform the backward pass on that. I wonder if there are better ideas.
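Something like this is what I have in mind (a plain `nn.Linear` stands in here for the DDP-wrapped model):

```python
import torch
import torch.nn as nn

# Stand-in for the DDP-wrapped EMA model.
model = nn.Linear(4, 2)

out = model(torch.randn(8, 4))

# Fake zero loss: backward still runs (so DDP's reducer hooks would
# fire in lockstep across ranks), but every gradient it produces is
# exactly zero, so no optimizer step can actually change the weights.
loss = out.sum() * 0.0
loss.backward()
```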
After some attempts, dist._all_gather_base in SyncBatchNorm kept timing out after some training time for the exponential-moving-average model that is wrapped in DDP.
Thank you @kumpera for your reply. I’m calling the model with SyncBatchNorm under the torch.no_grad() context manager. The timeout happens after training works fine for a while (not at a particular epoch/step). I will try the TORCH_DISTRIBUTED_DEBUG env variable and look into the details.
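For reference, this is what I plan to try (`train.py` is a placeholder for my actual launch script):

```shell
# Enable detailed logging of collectives before launching, to see
# which rank/collective is stuck when the timeout hits.
export TORCH_DISTRIBUTED_DEBUG=DETAIL
# then launch as usual, e.g.:
# torchrun --nproc_per_node=2 train.py
```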
Do you mean I should change that part of the code to fit my needs?
In my case that sync timed out. SyncBatchNorm works as expected in the forward pass. The issue in my case is that I have multiple data loaders, each with multiple workers, and it turns out they can create a deadlock. Once I reduced num_workers to 1, the issue was resolved.
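For anyone hitting the same thing, the fix was just this (the datasets here are stand-ins; only the num_workers setting matters):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in datasets.
ds_a = TensorDataset(torch.randn(32, 4))
ds_b = TensorDataset(torch.randn(32, 4))

# Multiple loaders each with many workers deadlocked in my setup;
# dropping each loader to a single worker resolved it.
loader_a = DataLoader(ds_a, batch_size=8, num_workers=1)
loader_b = DataLoader(ds_b, batch_size=8, num_workers=1)
```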