DistributedDataParallel: Gradient computation has been modified by an inplace operation

Hi,

I'm trying to get DistributedDataParallel working for my application, but I run into the following exception when I start the training. It happens when the backward() function of the loss is called:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1024]] is at version 5; expected version 4 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

This is my model:

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self, output_channels: int, dropout: float = 0.5):
            super().__init__()
            self.conv_layer0 = self._make_conv_layer(1, 64)
            self.conv_layer1 = self._make_conv_layer(64, 128)
            self.conv_layer2 = self._make_conv_layer(128, 256)
            self.conv_layer3 = self._make_conv_layer(256, 512)
            self.conv_layer4 = self._make_conv_layer(512, 1024)
            self.max_pooling = nn.MaxPool3d((2, 2, 2))
            self.headnet = self._make_headnet(2 * 2 * 2 * 1024, 2048, output_channels, dropout=dropout)

        @staticmethod
        def _make_conv_layer(in_c: int, out_c: int):
            conv_layer = nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=3),
                nn.BatchNorm3d(out_c),
                nn.LeakyReLU(),
                nn.Conv3d(out_c, out_c, kernel_size=3),
                nn.BatchNorm3d(out_c),
                nn.LeakyReLU(),
            )
            return conv_layer

        @staticmethod
        def _make_headnet(in_c1: int, out_c1: int, out_head: int, dropout: float) -> nn.Sequential:
            headnet = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(in_c1, out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1, out_c1),
                nn.LeakyReLU(),
                nn.Linear(out_c1, out_head),
            )
            return headnet

        def forward(self, inputtensor):
            """
            Forward pass through the network
            :param inputtensor: Input tensor
            """
            out = self.conv_layer0(inputtensor)
            out = self.conv_layer1(out)
            out = self.conv_layer2(out)
            out = self.max_pooling(out)
            out = self.conv_layer3(out)
            out = self.conv_layer4(out)
            out = self.max_pooling(out)
            out = out.reshape(out.size(0), -1)  # flatten
            out = self.headnet(out)
            out = F.normalize(out, p=2, dim=1)

            return out

I can’t see any in-place operation in there, so I’m running out of ideas. I would be happy if someone could point me in a direction to look.

The model was running fine with DataParallel.
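Roughly, my DDP setup looks like this (a simplified sketch, not my exact code; the loss, the data loader, and the output_channels value are placeholders):

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(rank: int, world_size: int):
        # one process per GPU, launched e.g. via torch.multiprocessing.spawn
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

        model = Model(output_channels=128).to(rank)   # 128 is a placeholder
        ddp_model = DDP(model, device_ids=[rank])

        criterion = torch.nn.MSELoss()                # placeholder loss
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-4)

        for inputs, targets in loader:                # loader: DataLoader with a DistributedSampler
            optimizer.zero_grad()
            loss = criterion(ddp_model(inputs.to(rank)), targets.to(rank))
            loss.backward()                           # <- the RuntimeError is raised here
            optimizer.step()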

Best,
Thorsten

I just found out that training starts if I set broadcast_buffers=False in DistributedDataParallel, but I’m not sure what that option does…
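For reference, this is the only change, in the DDP wrapper from the sketch above:

    ddp_model = DDP(
        model,
        device_ids=[rank],
        broadcast_buffers=False,  # with this set to False, training starts without the error
    )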

Related issue: Distributed: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead · Issue #62474 · pytorch/pytorch · GitHub

This may be related to the BatchNorm layers; replacing those with SyncBatchNorm (see SyncBatchNorm — PyTorch 1.9.1 documentation) would probably help.
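Something along these lines should work as a sketch (model construction and process-group setup as in your snippet; convert_sync_batchnorm swaps every BatchNorm*d in the module tree for SyncBatchNorm):

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    model = Model(output_channels=128).to(rank)                    # placeholder arguments
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)   # BatchNorm3d -> SyncBatchNorm
    ddp_model = DDP(model, device_ids=[rank])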

Hi Howard,

thanks for helping me out!

Replacing nn.BatchNorm3d(out_c) with nn.SyncBatchNorm(out_c) worked.

I noticed that there is no nn.SyncBatchNorm3d, though. Do I need to take any extra measures here?

Best,
Thorsten

SyncBatchNorm is not tied to a fixed input dimensionality; you only give it the number of channels (num_features), so as long as your input dimensions line up it will be fine. Here is the follow-up issue to track this, along with some additional workarounds: BatchNorm runtimeError: one of the variables needed for gradient computation has been modified by an inplace operation · Issue #66504 · pytorch/pytorch · GitHub.
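For example, in your _make_conv_layer the swap is a drop-in replacement; num_features stays the channel count, and the layer sees the same 5D (N, C, D, H, W) activations that BatchNorm3d did:

    @staticmethod
    def _make_conv_layer(in_c: int, out_c: int):
        return nn.Sequential(
            nn.Conv3d(in_c, out_c, kernel_size=3),
            nn.SyncBatchNorm(out_c),   # num_features == number of channels, same as BatchNorm3d
            nn.LeakyReLU(),
            nn.Conv3d(out_c, out_c, kernel_size=3),
            nn.SyncBatchNorm(out_c),
            nn.LeakyReLU(),
        )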