PyTorch distributed loss.backward() errors out with CudnnBatchNormBackward, inplace operation

I have some simple model code, transfer-learned from ResNet, and when I run it without distributed training everything works fine. However, when I try it in distributed mode, I get this weird error:

Error detected in CudnnBatchNormBackward.

and then:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The model in question is a vanilla ResNet, loaded like so:

import torch
import torch.nn.functional as F
from torchvision import models


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # pretrained ResNet-50 backbone, with BatchNorm converted to
        # SyncBatchNorm for distributed training
        backbone = models.resnet50(pretrained=True)
        self.backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)

    def forward(self, x):
        y1 = self.backbone(x)
        y_n = F.normalize(y1, dim=-1)  # L2-normalize the embeddings
        return y_n

and the training loop looks like so:

    for batch_idx, ((img1, img2), _) in enumerate(train_loader):

        if args.gpu is not None:
            img1 = img1.cuda(args.gpu, non_blocking=True)
            img2 = img2.cuda(args.gpu, non_blocking=True)

        optimizer.zero_grad()

        # two forward passes through the same (DDP-wrapped) model per iteration
        out_1 = model(img1)
        out_2 = model(img2)

        loss = contrastive_loss_fn(out_1, out_2)

        loss.backward()

        optimizer.step()

It errors out at loss.backward(). I suspect it is because I run the model twice per iteration, but I am not sure of this.

Any pointers would be great… I haven't been able to troubleshoot this for days now!

PS: I have tried cloning the outputs and using SyncBatchNorm… neither seems to help!

I can suggest setting broadcast_buffers=False in the DDP module constructor, as was mentioned here.
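
For reference, here is a minimal sketch of where the flag goes, assuming the Net class and args.gpu from your question; everything else is the standard torch.nn.parallel.DistributedDataParallel wrapper:

import torch

# Minimal sketch, assuming the Net class and args.gpu from the question.
# broadcast_buffers=False stops DDP from re-broadcasting module buffers
# (the BatchNorm running stats) from rank 0 at the start of every forward
# pass; that in-place buffer update between the two forward calls is what
# trips the version-counter check during backward.
model = Net().cuda(args.gpu)
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[args.gpu],
    broadcast_buffers=False,
)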


I spent about 6 days hunting for this solution :frowning:
Thank you VERY much @pbelevich, genius stuff!
I just have a follow-up question (it all works fine now with the buffer flag): what does this do to performance? Are there any gotchas I need to be aware of?
Thank you.