Issue with using DataParallel - RuntimeError: Output 0 of BroadcastBackward is a view and its base or another view of its base has been modified inplace

Hi there, I’m trying to use DataParallel with a model implementation, but I got this weird error.

My code is roughly like this:

import torch
import torch.nn as nn
from torch.utils import data
from torchvision import datasets, transforms

# constructing the NN
net = ...

device = torch.device('cuda:0')
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)
net.to(device)

# got training data
tr = data.DataLoader(datasets.CIFAR10('../data', train=True, download=True, transform=transforms.ToTensor()),
                     batch_size=128, shuffle=True, num_workers=1, pin_memory=True)

for input_data, _ in tr:
    # (B, 3, 32, 32) -> (B, 3, 32, 32, 1) -> (B, 1, 32, 32, 3): treat RGB as the Conv3d depth dimension
    inp = input_data.unsqueeze(-1).transpose(1, 4)
    inp = inp.to(device)
    output = net(inp)

This threw an exception as follows:

RuntimeError: Output 0 of BroadcastBackward is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

It says an in-place operation modified the tensor, but I don’t think I’m doing any in-place operations here. What could be wrong?

Hi,

You need to send the net to the right device before wrapping it in DataParallel, I think. Otherwise you can get weird behavior like this :slight_smile:
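
Something like this, i.e. move the model to the device first and only then wrap it:

net = net.to(device)        # put the base model on cuda:0 first
net = nn.DataParallel(net)  # then wrap it for multi-GPU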

Hi @albanD, thanks for the quick response. I looked at https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html#create-model-and-dataparallel, and it first wraps the model in DataParallel and then moves it to the device, which is what I did. Could you elaborate a bit on what I am missing here?

Interesting… Maybe I’m wrong then :smiley: You don’t need to do that.

I am pretty sure, though, that the first op in your model is an in-place one, no?
Can you share the beginning of the forward function for net?

@albanD sure, I’m working through a PixelCNN tutorial myself, so I can share my full code here.

class MaskedConv3d(nn.Conv3d):
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # build a PixelCNN-style causal mask with the same shape as the weight
        self.register_buffer("mask", self.weight.clone())
        out_channel, in_channel, h, w, c = self.weight.size()
        self.mask.fill_(1)
        self.mask[:, :, h // 2:, w // 2 + 1:] = 0
        self.mask[:, :, h // 2 + 1:] = 0
        self.mask[:, :, h // 2, w // 2, c // 2 + (mask_type == 'B')] = 0

    def forward(self, x):
        # zero out the masked weights (in place) before the convolution
        with torch.no_grad():
            self.weight *= self.mask
        return super().forward(x)

# fm is the number of feature maps (defined elsewhere)
net = nn.Sequential(
    MaskedConv3d('A', 1, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    MaskedConv3d('B', fm, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    MaskedConv3d('B', fm, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    MaskedConv3d('B', fm, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    MaskedConv3d('B', fm, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    MaskedConv3d('B', fm, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    MaskedConv3d('B', fm, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    MaskedConv3d('B', fm, fm, (7, 7, 3), 1, (3, 3, 1), bias=False), nn.BatchNorm3d(fm), nn.ReLU(inplace=True),
    nn.Conv3d(fm, 256, 1),
)

How do I know whether the first op is in-place or not?
Is the self.mask.fill_ causing issues?
Why isn’t an in-place operation allowed here? Is it because you cannot modify values in place once they have been broadcast across multiple GPUs?

Hi,

The *= is an in-place operation, so that is the issue.
You might want to do the masking before the forward in DataParallel to avoid issues.

Any suggestions on how I would “do the masking before the forward in DataParallel to avoid issues”?

would self.weight = self.weight * self.mask work?

Given that you use torch.no_grad(), I guess that you don’t want this to be taken into account in the backprop. So the simplest solution is just to apply this in-place masking before doing the DataParallel forward.
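
For example, something along these lines (a rough sketch; apply_masks is just an illustrative helper, and it assumes the self.weight *= self.mask line is removed from MaskedConv3d.forward):

@torch.no_grad()
def apply_masks(model):
    # mask the real parameters on the underlying module, before
    # DataParallel broadcasts them to the per-GPU replicas
    for m in model.modules():
        if isinstance(m, MaskedConv3d):
            m.weight *= m.mask

for input_data, _ in tr:
    inp = input_data.unsqueeze(-1).transpose(1, 4).to(device)
    apply_masks(net.module if isinstance(net, nn.DataParallel) else net)
    output = net(inp)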


In my situation, when multiple GPUs are used, only DDP mode works, but not DP. One can try DDP even with 1 node with multiple GPUs.
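
For reference, a minimal single-node DDP sketch (assuming a launch with torchrun --nproc_per_node=<num_gpus> train.py; build_net() is a hypothetical stand-in for the model construction above):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")          # one process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])       # set by torchrun
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    net = build_net().to(device)                     # hypothetical model constructor
    net = DDP(net, device_ids=[local_rank])          # no cross-GPU broadcast views here

    # ... training loop; pair the DataLoader with a DistributedSampler
    # so each process sees its own shard of the data ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()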