I was wondering if there are any docs on how to use SyncBatchNorm with SWA. I have a pretrained MobileNet model which I converted to SyncBatchNorm using:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
and then do the DDP setup, and at the end of training I tried to update the batch stats using the utility function, like so:
print('==> UPDATING BATCH STATS')
print('==> FINISHED UPDATING BATCH STATS')
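For context, a minimal sketch of what that stats-update pass looks like with the `torch.optim.swa_utils` helper (this uses a tiny stand-in model and loader, not the actual MobileNet/DDP setup, and assumes the `update_bn` utility is the one being called between those two prints):

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

# tiny stand-in for the real model; the actual setup wraps a mobilenet
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
swa_model = AveragedModel(model)

# stand-in loader: update_bn accepts an iterable of input batches
loader = [torch.randn(4, 3, 16, 16) for _ in range(3)]

print('==> UPDATING BATCH STATS')
# recomputes running_mean/running_var with one full pass over the data
update_bn(loader, swa_model)
print('==> FINISHED UPDATING BATCH STATS')
```

With SyncBatchNorm layers in the model, this pass runs collective ops across processes, which is why a mismatch between ranks can make it hang rather than error out.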
It starts updating the batch stats and then nothing happens… literally nothing… no error, no exit… it just sits there…
When I remove syncbatchnorm and train and update stats it all works perfectly.
I browsed through the source of the above:
and I am not exactly sure what is causing this… but I suspect:
and that it is not able to find the module in the model? I'm totally clueless here.
I seem to get a better loss with SyncBatchNorm, and I would like to update the batch stats and then run inference, but the fact that the update is not working is blocking me.
Do you have a minimal but full script that one can run?
If it's hanging and not doing anything, it probably requires some deeper debugging of what's happening.
Thank you VERY much for your reply. Well, when I read the source, I see this:
So I assumed this was the correct way to do it, and it works, but now I have a different issue. I want to run inference on CPU, so I normally do:
traced_script_module = torch.jit.trace(swa_model, (data))
raise ValueError("SyncBatchNorm expected input tensor to be on GPU")
ValueError: SyncBatchNorm expected input tensor to be on GPU
But I want to run inference on CPU via JIT tracing, and I am not sure how I can get around this issue.
I presume it is impossible to jit trace with the model on CPU and the data on GPU (due to the device mismatch).
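One workaround, sketched below: since PyTorch has no built-in inverse of `convert_sync_batchnorm`, a common approach is a small hand-written helper that recursively swaps each `SyncBatchNorm` back to a plain `BatchNorm2d` (assuming all norm layers are 2d, as in a MobileNet), carrying over the affine parameters and running stats, so the model can then be moved to CPU and traced:

```python
import torch

def revert_sync_batchnorm(module: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: recursively replace SyncBatchNorm with
    BatchNorm2d, copying weights and running statistics."""
    mod = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        mod = torch.nn.BatchNorm2d(
            module.num_features, module.eps, module.momentum,
            module.affine, module.track_running_stats)
        if module.affine:
            mod.weight = module.weight
            mod.bias = module.bias
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        mod.add_module(name, revert_sync_batchnorm(child))
    return mod

# toy stand-in for the converted model: revert, move to CPU, then trace
net = torch.nn.Sequential(torch.nn.Conv2d(3, 4, 3), torch.nn.SyncBatchNorm(4))
cpu_net = revert_sync_batchnorm(net).cpu().eval()
traced_script_module = torch.jit.trace(cpu_net, torch.randn(1, 3, 8, 8))
```

After the revert, the "SyncBatchNorm expected input tensor to be on GPU" check no longer applies, since no `SyncBatchNorm` modules remain in the traced graph.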