I need to update only the batch norm running mean and variance during training across GPUs, so I want to leverage SyncBatchNorm for that. I don't need gradients for any of the parameters in this model (its parameters are updated from the online model). Is there a way to use SyncBatchNorm for that purpose only?
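For context, here is a minimal sketch of what I have in mind (the module, shapes, and input are made up, and it assumes the process group has already been initialized, e.g. via torchrun):

```python
import torch
import torch.nn as nn

# Hypothetical EMA model; in my setup its weights are copied/averaged from
# the online model, so none of its parameters need gradients.
ema_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
ema_model = nn.SyncBatchNorm.convert_sync_batchnorm(ema_model).cuda()

for p in ema_model.parameters():
    p.requires_grad_(False)

# train() mode so BatchNorm keeps updating running_mean/running_var;
# no_grad() because I don't want any autograd state for this model.
ema_model.train()
batch = torch.randn(4, 3, 32, 32, device="cuda")  # placeholder input
with torch.no_grad():
    _ = ema_model(batch)  # SyncBatchNorm gathers the stats across ranks here
```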
I have attempted to use DDP without calling backward, but that creates sync issues. One thing I'm considering is to create a fake zero loss and perform the backward pass as a workaround, roughly as sketched below. I wonder if there are better ideas.
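Sketch of the workaround (`ema_model`, `batch`, and `local_rank` are placeholders from the snippet above; this assumes the parameters are left with requires_grad=True, since DDP refuses a module with no trainable parameters):

```python
from torch.nn.parallel import DistributedDataParallel as DDP

# ddp_ema wraps the SyncBatchNorm-converted EMA model.
ddp_ema = DDP(ema_model, device_ids=[local_rank])

out = ddp_ema(batch)          # forward in train() mode updates running stats
fake_loss = out.sum() * 0.0   # identically zero, but still part of the graph
fake_loss.backward()          # DDP's allreduce fires with all-zero grads,
                              # so every rank performs the same collectives
```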
After some attempts, dist._all_gather_base in SyncBatchNorm kept timing out after some training time for the exponential moving average model that is wrapped in DDP.
Thank you @kumpera for your reply. I'm calling the model with SyncBatchNorm under the torch.no_grad() context manager. The timeout happens after training works fine for a while (not at a particular epoch/step). I will try the TORCH_DISTRIBUTED_DEBUG env variable and look into the details.
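For reference, this is roughly how I plan to turn on the extra logging (just a sketch; the variable can also be set in the shell before launching the job with torchrun):

```python
import os

# Set before initializing the process group so the debug level is picked up;
# "INFO" is a lighter alternative to "DETAIL".
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

import torch.distributed as dist
dist.init_process_group(backend="nccl")
```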
Do you mean I should change that part of the code to fit my needs?