SyncBatchNorm with SWA

Hi there,

I was wondering if there are any docs on how to use SyncBatchNorm with SWA. I have a pretrained MobileNet model which I converted to SyncBatchNorm using:

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

and then set up DDP, and at the end of training I tried to update the batch stats using the utility function like so:

    print('==> UPDATING BATCH STATS')
    torch.optim.swa_utils.update_bn(train_loader, swa_model, gpu_device)
    print('==> FINISHED UPDATING BATCH STATS')
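For context, here is a minimal sketch of the flow I am using (the toy `Sequential` model stands in for my MobileNet, and the `update_bn` call is commented out because it needs the real loader and device):

```python
import torch

# convert_sync_batchnorm swaps every BatchNorm*d submodule for a
# SyncBatchNorm carrying the same weights and running stats.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# SWA wraps the (converted) model in an AveragedModel; training then
# proceeds under DDP as usual.
swa_model = torch.optim.swa_utils.AveragedModel(model)

# After training, batch stats would be refreshed with:
# torch.optim.swa_utils.update_bn(train_loader, swa_model, device=gpu_device)
```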

It starts updating the batch stats and then hangs… literally nothing happens… no error, no exit… it just sits there…

When I remove SyncBatchNorm and train and update the stats, everything works perfectly.
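One possible cause (an assumption on my part, not confirmed from the script): SyncBatchNorm's forward pass performs an `all_reduce` across the process group, so `update_bn` has to be entered by *every* rank. If only one rank calls it (e.g. inside a rank-0 guard), the collective never completes and that rank hangs with no error. A sketch of what the fix would look like (the distributed calls are commented out since they need an initialized process group):

```python
import torch

# Run update_bn on ALL ranks so SyncBatchNorm's collectives have
# partners on every process:
# torch.optim.swa_utils.update_bn(train_loader, swa_model, device=gpu_device)

# Only the save needs to be restricted to a single rank:
# if torch.distributed.get_rank() == 0:
#     torch.save(swa_model.module.state_dict(), "swa_model.pt")
```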

I browsed through the source of the above:

and I am not exactly sure what is causing this… but I suspect:

and that it is not able to find the module in the model? I'm totally clueless :sob:

I seem to get better loss with SyncBatchNorm and I would like to update the batch stats and run inference, but the fact that the updates hang is blocking me :sob:


Do you have a minimal but complete script that one can run?
If it's hanging and not doing anything, it probably requires some deeper debugging into what's happening.

Hi @smth

Thank you VERY much for your reply. Well, when I read the source, I see this:

So, I assumed this was the correct way to do it and did it… It works, but now I have a different issue. I want to run the model on the CPU for inference, so I normally do:

    traced_script_module = torch.jit.trace(swa_model, (data,))

It complains:

    raise ValueError("SyncBatchNorm expected input tensor to be on GPU")
ValueError: SyncBatchNorm expected input tensor to be on GPU

But I want to run the model on the CPU via JIT tracing, and I am not sure how to get around this issue :sob:

I presume it is impossible to JIT trace with the model on the CPU and the data on the GPU (due to the device mismatch).
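One workaround I can think of (a sketch, not an official API; `revert_sync_batchnorm` is a hypothetical helper I wrote, not part of torch) is to convert the SyncBatchNorm layers back into plain BatchNorm, reusing their learned weights and running stats, and then trace the reverted model entirely on the CPU:

```python
import torch

def revert_sync_batchnorm(module):
    # Hypothetical helper: recursively replace every SyncBatchNorm
    # with a BatchNorm2d that shares its parameters and running stats.
    mod = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        mod = torch.nn.BatchNorm2d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            mod.weight = module.weight
            mod.bias = module.bias
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        mod.add_module(name, revert_sync_batchnorm(child))
    return mod
```

With the reverted model in `eval()` mode and both model and example input on the CPU, `torch.jit.trace` should no longer hit the "expected input tensor to be on GPU" check.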