I was wondering if there is any documentation on how to use SyncBatchNorm with SWA. I have a pretrained MobileNet model which I converted to use SyncBatchNorm via:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
and then set up DDP as usual, and at the end of training I tried to update the batch norm statistics using the utility function like so:
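For context, the overall flow I'm describing looks roughly like this (a tiny stand-in model instead of MobileNet, a list of random batches instead of my real loader, and run here in a single process, where SyncBatchNorm falls back to ordinary batch norm):

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, update_bn

# Tiny stand-in for the pretrained MobileNet (names/shapes are mine).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Convert every BatchNorm layer to SyncBatchNorm; the conversion itself
# works in a single process, though syncing only happens under DDP.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# (In the real setup the model is then wrapped in DistributedDataParallel.)
swa_model = AveragedModel(model)

# Stand-in loader of random batches; under DDP each rank gets its shard.
loader = [torch.randn(4, 3, 16, 16) for _ in range(3)]

# Recompute the BatchNorm running statistics for the averaged model.
update_bn(loader, swa_model)
```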
It starts updating the batch stats and then nothing happens… literally nothing… no error, no exit… it just sits there…
When I remove SyncBatchNorm, training and the stats update both work perfectly.
I browsed through the source of the above:
and I am not exactly sure what is causing this… but I suspect:
and that it is not able to find the module in the model? I'm totally clueless here.
I seem to get a better loss with SyncBatchNorm, and I would like to update the batch stats and run inference, but the fact that the update hangs is blocking me.
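To spell out my suspicion (this is my own guess at the mechanism, and the helper name below is made up, not a torch API): in training mode, SyncBatchNorm's forward all-reduces batch statistics across the process group, so every DDP rank has to participate in the stats update. If only rank 0 calls `update_bn`, the collective never completes, which would explain a silent hang. The workaround I'd try is calling it on every rank:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim.swa_utils import AveragedModel, update_bn

def refresh_bn_stats(loader, swa_model):
    # Call this on *every* rank, each with its own shard of the data.
    # SyncBatchNorm's training-mode forward all-reduces batch statistics,
    # so a rank that skipped this call would leave the others blocked.
    update_bn(loader, swa_model)
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # keep ranks in step before saving / inference

# Single-process smoke run: with no process group initialized,
# SyncBatchNorm falls back to plain batch norm and runs on CPU.
net = nn.SyncBatchNorm.convert_sync_batchnorm(
    nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
)
swa = AveragedModel(net)
refresh_bn_stats([torch.randn(2, 3, 8, 8) for _ in range(2)], swa)
```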
Do you have a minimal but complete script that one can run?
If it’s hanging and not doing anything, it will probably require some deeper debugging to see what’s happening.
Thank you VERY much for your reply. Well, when I read the source, I see this:
So, I assumed this was the correct way to do it and did it… It works, but now I have a different issue. I want to run inference on the CPU, so I normally do:
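In case it helps anyone else hitting this: in some PyTorch versions SyncBatchNorm complains about CPU inputs, and as far as I can tell there is no built-in inverse of `convert_sync_batchnorm`, so one workaround is to manually swap the SyncBatchNorm layers back to regular BatchNorm2d before moving the model to CPU. A sketch (the helper name is mine, not a torch API; I assume 2d batch norm throughout):

```python
import copy
import torch
import torch.nn as nn

def revert_sync_batchnorm(module):
    # Hypothetical inverse of convert_sync_batchnorm: replace every
    # SyncBatchNorm with a BatchNorm2d carrying the same parameters
    # and running statistics, so the model can run on CPU.
    mod = module
    if isinstance(module, nn.SyncBatchNorm):
        mod = nn.BatchNorm2d(module.num_features, module.eps,
                             module.momentum, module.affine,
                             module.track_running_stats)
        if module.affine:
            mod.weight = module.weight
            mod.bias = module.bias
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        mod.add_module(name, revert_sync_batchnorm(child))
    return mod

# Stand-in model; deepcopy so the original (GPU/DDP) model is untouched.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
cpu_model = revert_sync_batchnorm(copy.deepcopy(model)).eval()
out = cpu_model(torch.randn(1, 3, 16, 16))  # plain CPU forward pass
```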