SWA Manual Update OOM

Hi guys!

I am trying to use SWA with my custom dataloader, but I have a question. This is my code:

swa_model.train()
for indx, batch in enumerate(train_loader):
    image = batch["image"].type(torch.float).cuda()
    _ = swa_model(image)

But this runs out of memory fast. If I wrap it in torch.no_grad() it runs without problems… Maybe gradients are not being cleared or something, but it is not clear to me how to do this properly for the SWA model.

The SWA documentation says: “If your dataloader has a different structure, you can update the batch normalization statistics of the swa_model by doing a forward pass with the swa_model on each element of the dataset.”

Wrapping the loop in torch.no_grad() still updates the batch norm statistics, because the running mean and variance are buffers updated during the forward pass, not gradients; no_grad only skips building the autograd graph, which is what was eating the memory:

swa_model.train()
with torch.no_grad():
    for indx, batch in enumerate(train_loader):
        image = batch["image"].type(torch.float).cuda()
        _ = swa_model(image)
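
For reference, torch.optim.swa_utils.update_bn does essentially this same loop internally, but it additionally resets the running statistics and sets momentum to None so the estimate becomes a plain average over the whole dataset rather than an exponential moving average. A rough sketch of doing the same by hand for a dict-style dataloader (the helper name and the "image" key are just taken from this example, and the device argument is an assumption) could look like:

import torch

def update_bn_from_dict_loader(loader, model, device="cuda"):
    # Reset BN running stats and remember the original momenta.
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum
            # momentum=None makes BN keep a cumulative (equally weighted) average
            module.momentum = None

    if not momenta:
        return  # no BN layers, nothing to do

    was_training = model.training
    model.train()
    with torch.no_grad():
        for batch in loader:
            image = batch["image"].float().to(device)
            model(image)

    # Restore the original momenta and training mode.
    for module, momentum in momenta.items():
        module.momentum = momentum
    model.train(was_training)

update_bn_from_dict_loader(train_loader, swa_model)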

You can check that the running statistics were actually updated with:

for module in swa_model.modules():
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        print(module.running_mean)
        print(module.running_var)
        print(module.momentum)
        break
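
If the forward passes took effect, running_mean and running_var will have changed from their initial values (zeros and ones for a freshly constructed BN layer), which confirms the buffers are updated even with autograd disabled. momentum just prints the layer's configured value (0.1 by default) and is not changed by the update.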