DataParallel breaks the forward() method

Hi,

When I wrap my model in nn.DataParallel, my forward() method no longer receives inputs of the expected size. For example:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x, y

net = Net()

device = torch.device('cuda')
net = nn.DataParallel(net)
net = net.to(device)

x = torch.ones((1024, 3, 100, 100))
y = torch.ones((1, 3, 100, 100))

x_out, y_out = net(x, y)

print(x_out.shape)
# torch.Size([128, 3, 100, 100])

I have 8 GPUs, so it seems I only get back the result from one of them. If I change the first dimension of y to match x, I get back the original shape, as expected.
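
Presumably the scatter step chunks every positional argument along dim 0. A quick check with torch.chunk (my assumption for what scatter uses under the hood) shows that y can only ever produce a single chunk:

import torch

x = torch.ones((1024, 3, 100, 100))
y = torch.ones((1, 3, 100, 100))

# x splits cleanly into 8 chunks of 128 samples each
print(len(x.chunk(8, dim=0)))  # 8
# y has size 1 in dim 0, so at most 1 chunk comes back
print(len(y.chunk(8, dim=0)))  # 1

If only one chunk of y exists, only one replica can receive an (x, y) pair, which would explain getting a single GPU's output back.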

My guess is that something goes wrong when y is coalesced across the “batch” dimension on the way back. Does removing the first dimension of y (e.g., y = y.squeeze(0)) before passing it to the net work around this?

No, in that case I get torch.Size([384, 3, 100, 100])
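
That number is at least consistent with chunking along dim 0: after the squeeze, y has size 3 in dim 0, so it can produce at most 3 chunks, and gathering x from 3 replicas gives 3 × 128 = 384. A quick sketch, again assuming scatter relies on torch.chunk:

import torch

y = torch.ones((1, 3, 100, 100)).squeeze(0)  # shape (3, 100, 100)
print(len(y.chunk(8, dim=0)))  # 3 -> only 3 of the 8 replicas get inputs
print(3 * 128)                 # 384, matching the gathered output size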

@VitalyFedyunin for the DataParallel question.

If you add some debug statements to the forward method, you'll see that the processing goes wrong because of the second input, which cannot be chunked into 8 parts:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        print('x: {}, {}'.format(x.device, x.shape))
        print('y: {}, {}'.format(y.device, y.shape))
        return x, y

net = nn.DataParallel(Net()).to('cuda')

Your setup:

x = torch.ones((1024, 3, 100, 100))
y = torch.ones((1, 3, 100, 100))

x_out, y_out = net(x, y)

print(x_out.shape)

x: cuda:0, torch.Size([128, 3, 100, 100])
y: cuda:0, torch.Size([1, 3, 100, 100])
torch.Size([128, 3, 100, 100])

I’m not sure what the expected behavior here is, as it seems to fall back to the smallest possible split (I would have expected an error to be raised).
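
The silent fallback seems to come from how scatter handles multiple arguments: each tensor is chunked separately along dim 0, and the per-tensor chunk lists are then zipped together, so the argument with the fewest chunks limits the replica count. A minimal sketch using torch.nn.parallel.scatter (assumes 8 visible GPUs):

import torch
from torch.nn.parallel import scatter

x = torch.ones((1024, 3, 100, 100))
y = torch.ones((1, 3, 100, 100))

# x yields 8 chunks, y yields 1; zipping the two lists
# leaves a single (x_chunk, y_chunk) pair
inputs = scatter((x, y), target_gpus=list(range(8)))
print(len(inputs))  # 1 -> only cuda:0 receives an (x, y) pair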

If you use tensors with the same size in dim 0, it works as expected:

x = torch.ones((1024, 3, 100, 100))
y = torch.ones((1024, 3, 100, 100))

x_out, y_out = net(x, y)

print(x_out.shape)

x: cuda:0, torch.Size([128, 3, 100, 100])
y: cuda:0, torch.Size([128, 3, 100, 100])
x: cuda:1, torch.Size([128, 3, 100, 100])
y: cuda:1, torch.Size([128, 3, 100, 100])
x: cuda:2, torch.Size([128, 3, 100, 100])
x: cuda:3, torch.Size([128, 3, 100, 100])
y: cuda:2, torch.Size([128, 3, 100, 100])
x: cuda:4, torch.Size([128, 3, 100, 100])
x: cuda:5, torch.Size([128, 3, 100, 100])
y: cuda:4, torch.Size([128, 3, 100, 100])
y: cuda:5, torch.Size([128, 3, 100, 100])
x: cuda:6, torch.Size([128, 3, 100, 100])
x: cuda:7, torch.Size([128, 3, 100, 100])
y: cuda:6, torch.Size([128, 3, 100, 100])
y: cuda:7, torch.Size([128, 3, 100, 100])
y: cuda:3, torch.Size([128, 3, 100, 100])
torch.Size([1024, 3, 100, 100])
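
If y is meant to be shared across the whole batch, two possible workarounds come to mind (sketches only; names like y_expanded are made up for illustration): expand y along dim 0 so it scatters like x, or make y part of the module so DataParallel replicates it to every GPU instead of scattering it:

import torch
import torch.nn as nn

# Option 1: expand y to the full batch size so both tensors chunk into 8 parts
y_expanded = y.expand(x.size(0), -1, -1, -1)
x_out, y_out = net(x, y_expanded)

# Option 2: register y as a buffer; DataParallel replicates buffers,
# so every replica sees the complete tensor instead of a slice
class Net(nn.Module):
    def __init__(self, y):
        super().__init__()
        self.register_buffer('y', y)

    def forward(self, x):
        return x, self.y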