I am trying to use SWA with my custom dataloader, but I have a question. This is my code:
```python
swa_model.train()
for indx, batch in enumerate(train_loader):
    image = batch["image"].type(torch.float).cuda()
    _ = swa_model(image)
```
But this runs out of memory fast. If I wrap it in torch.no_grad() it runs without problems… Maybe the gradients are not being cleared or something, but I'm not sure how to do this properly for the SWA model.
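For reference, here is a minimal sketch of the same loop with the forward pass wrapped in torch.no_grad(), assuming swa_model and train_loader are the objects from the snippet above; no autograd graph is built, so activations are freed right away while the BatchNorm running statistics still update:

```python
import torch

swa_model.train()  # BatchNorm layers must be in training mode to update running stats
with torch.no_grad():  # skip graph construction; only the forward pass is needed
    for batch in train_loader:
        image = batch["image"].type(torch.float).cuda()
        swa_model(image)  # output discarded; only BN running stats change
```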
The SWA docs note: “If your dataloader has a different structure, you can update the batch normalization statistics of the swa_model by doing a forward pass with the swa_model on each element of the dataset.”
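Below is a sketch of that “forward pass on each element” idea for a dict-style batch like the one above, loosely mirroring what torch.optim.swa_utils.update_bn does (reset the BN running statistics, then accumulate them over a single pass with cumulative averaging). The helper name, the device argument, and the "image" key handling are assumptions about this setup, not a drop-in answer:

```python
import torch

def update_bn_from_dict_loader(loader, model, device="cuda"):
    """Hypothetical helper: recompute BatchNorm running stats for `model`
    with one forward pass over `loader`, where each batch is a dict with
    an "image" key (as in the snippet above)."""
    bn_modules = [m for m in model.modules()
                  if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]
    if not bn_modules:
        return

    saved_momenta = [m.momentum for m in bn_modules]
    for m in bn_modules:
        m.reset_running_stats()
        m.momentum = None  # None -> cumulative moving average over the whole pass

    was_training = model.training
    model.train()
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].type(torch.float).to(device)
            model(image)

    # restore the original momenta and training mode
    for m, momentum in zip(bn_modules, saved_momenta):
        m.momentum = momentum
    model.train(was_training)

# usage, with the objects from the question:
# update_bn_from_dict_loader(train_loader, swa_model)
```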