Can DataParallel work with an arbitrary forward() function?

I've been using a network with the forward() function shown below. As you can see, the network is not built with nn.Sequential(), and forward() contains operations such as e8 = torch.cat((e6, c3), 1).

I wrap it with netG = torch.nn.DataParallel(netG, device_ids=[0, 1]) to run on multiple GPUs, but I get an error message along the lines of: expected 64 batches, not 32. I understand that DataParallel tries to split the full batch of 64 across the GPUs, but how can I make a network like the one below run on multiple GPUs?

import torch
import torch.nn as nn
import torch.nn.functional as F


class G_tconv(nn.Module):
    def __init__(self, nc, ngf):
        super(G_tconv, self).__init__()

        # encoder
        self.conv1 = nn.Conv2d(nc, ngf, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(ngf * 2)
        self.conv3 = nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(ngf * 4)
        self.conv4 = nn.Conv2d(ngf * 4, ngf * 8, 4, 2, 1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(ngf * 8)
        self.conv5 = nn.Conv2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(ngf * 8)

        self.convt1 = nn.ConvTranspose2d(1024, 128, 4, 1, 0, bias=False)  # not used in forward() below

        # decoder
        self.convt2 = nn.ConvTranspose2d(ngf * 8 + 128, ngf * 8, 4, 2, 1, bias=False)
        self.batchnorm5 = nn.BatchNorm2d(ngf * 8)
        self.convt3 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False)
        self.batchnorm6 = nn.BatchNorm2d(ngf * 4)
        self.convt4 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False)
        self.batchnorm7 = nn.BatchNorm2d(ngf * 2)
        self.convt5 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False)
        self.batchnorm8 = nn.BatchNorm2d(ngf)
        self.convt6 = nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False)

        # 1x1 / 3x3 refinement blocks applied inside the decoder
        self.conv_e1 = nn.Conv2d(ngf * 8, ngf * 2, 1, 1, 0, bias=False)
        self.bn_e1 = nn.BatchNorm2d(ngf * 2)
        self.conv_e2 = nn.Conv2d(ngf * 2, ngf * 8, 3, 1, 1, bias=False)
        self.bn_e2 = nn.BatchNorm2d(ngf * 8)
        self.conv_e3 = nn.Conv2d(ngf * 4, ngf, 1, 1, 0, bias=False)
        self.bn_e3 = nn.BatchNorm2d(ngf)
        self.conv_e4 = nn.Conv2d(ngf, ngf * 4, 3, 1, 1, bias=False)
        self.bn_e4 = nn.BatchNorm2d(ngf * 4)

        # projection of the second (1024-dim) input
        self.linear = nn.Linear(1024, 128)

    def forward(self, batchSize, input1, input2):
        # encoder
        e2 = F.relu(self.conv1(input1))
        e3 = F.relu(self.batchnorm1(self.conv2(e2)))
        e4 = F.relu(self.batchnorm2(self.conv3(e3)))
        e5 = F.relu(self.batchnorm3(self.conv4(e4)))
        e6 = F.relu(self.batchnorm4(self.conv5(e5)))

        # project and broadcast the conditioning input
        c1 = self.linear(input2.view(batchSize, 1024))
        c2 = c1.view(batchSize, 128, 1, 1)
        c3 = c2.expand(batchSize, 128, 4, 4)

        # concatenate along the channel dimension and decode
        e8 = torch.cat((e6, c3), 1)
        d1_ = F.relu(self.batchnorm5(self.convt2(e8)))
        d1 = F.relu(self.bn_e2(self.conv_e2(F.relu(self.bn_e1(self.conv_e1(d1_))))))
        d2_ = F.relu(self.batchnorm6(self.convt3(d1)))
        d2 = F.relu(self.bn_e4(self.conv_e4(F.relu(self.bn_e3(self.conv_e3(d2_))))))
        d3_ = F.relu(self.batchnorm7(self.convt4(d2)))
        d4_ = F.relu(self.batchnorm8(self.convt5(d3_)))
        d5_ = self.convt6(F.relu(d4_))

        o1 = F.tanh(d5_)

        return o1

When you use DataParallel, just double your batch size. Wouldn't that fix the error you are seeing?

This is not general advice; it only applies to @wzhang25's model above.
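For background on why the batch size matters here: DataParallel scatters the batch dimension of the inputs across the devices, so each replica's forward() only sees its share of the batch. Below is a minimal sketch of that behaviour, assuming two visible GPUs; the Probe module is purely illustrative and not from the thread.

import torch
import torch.nn as nn

class Probe(nn.Module):
    """Toy module that reports the per-replica batch size it receives."""
    def forward(self, x):
        print('this replica sees a batch of', x.size(0))
        return x

if torch.cuda.device_count() >= 2:
    probe = nn.DataParallel(Probe().cuda(), device_ids=[0, 1])
    # A batch of 64 is scattered as 32 + 32, so any size that was
    # hard-coded inside forward() no longer matches what a replica sees.
    probe(torch.randn(64, 3, 4, 4).cuda())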


Thanks.
No luck. When I double the batch size, the error message changes to: expected 64 batches, not 32.
Is DataParallel compatible with an arbitrary forward() function?

Hey, DataParallel is compatible with most forward() functions, but you seem to be doing something that won't work here. Your forward() accepts two inputs plus a batchSize argument; each replica of the model receives only its part of the batch, yet the same batchSize value is sent to all of them, so the hard-coded size no longer matches the slice each replica sees. It would be better to use, e.g., batchSize = input1.size(0) inside forward().
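To make that suggestion concrete, here is a sketch of how the forward() from the original post could derive the batch size from the tensor each replica actually receives. Only the encoder and conditioning part is shown; the decoder is assumed to stay exactly as posted above.

    def forward(self, input1, input2):
        # Each DataParallel replica gets only its slice of the batch,
        # so take the size from the tensor that actually arrived.
        batchSize = input1.size(0)

        e2 = F.relu(self.conv1(input1))
        e3 = F.relu(self.batchnorm1(self.conv2(e2)))
        e4 = F.relu(self.batchnorm2(self.conv3(e3)))
        e5 = F.relu(self.batchnorm3(self.conv4(e4)))
        e6 = F.relu(self.batchnorm4(self.conv5(e5)))

        c1 = self.linear(input2.view(batchSize, 1024))
        c2 = c1.view(batchSize, 128, 1, 1)
        c3 = c2.expand(batchSize, 128, 4, 4)

        e8 = torch.cat((e6, c3), 1)
        # ... rest of the decoder unchanged ...

The call site then becomes netG(input1, input2), with no explicit batch-size argument to get out of sync with the scattered inputs.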

Hi,

class Model(nn.Module):
    def __init__(self):
        ......

    def forward(self, input_data):
        ......

    def sample(self, input_data):
        ......

model = Model()
model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)

model(input_data) works on multiple GPUs,
but model.module.sample(input_data) only runs on a single GPU.

How can I make model.module.sample(input_data) run on multiple GPUs?
I hope you can understand what I mean.
Thanks.


Same problem here.

Routing the call through forward() fulfils the need (see the sketch at the end of this post), but I still wonder whether you found a proper way to fix it?

Thanks
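For reference, one way to realize the workaround mentioned above (routing everything through forward(), since DataParallel only parallelizes calls made through forward()) might look like the sketch below. The mode argument and the placeholder layer are hypothetical and only illustrate the dispatch; they are not from the thread.

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(8, 8)  # placeholder layer

    def sample(self, input_data):
        # sampling logic would go here
        return self.fc(input_data)

    def forward(self, input_data, mode='forward'):
        # DataParallel only scatters/gathers calls that go through
        # forward(), so dispatch to sample() from inside it.
        if mode == 'sample':
            return self.sample(input_data)
        return self.fc(input_data)

if torch.cuda.device_count() >= 2:
    model = nn.DataParallel(Model().cuda(), device_ids=[0, 1])
    out = model(torch.randn(64, 8).cuda())                      # forward path on both GPUs
    samples = model(torch.randn(64, 8).cuda(), mode='sample')   # sample() now runs on both GPUs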