I’ve been using a network with the forward() function shown below. As you can see, the network is not built with nn.Sequential(); instead, forward() composes operations directly, e.g. e8 = torch.cat((e6, c3), 1).
To run on multiple GPUs I wrap it with netG = torch.nn.DataParallel(netG, device_ids=[0, 1]), but I get an error message along the lines of: expected 64 batches, not 32. I understand that DataParallel tries to split the full batch of 64 across the GPUs. How can I make a network like the one below run on multiple GPUs?
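For context, this is roughly how I construct and call it; the nc/ngf values and tensor sizes below are placeholders, not my exact settings:

netG = G_tconv(nc=3, ngf=64)
netG = torch.nn.DataParallel(netG, device_ids=[0, 1]).cuda()

input1 = torch.randn(64, 3, 128, 128).cuda()  # image batch of 64
input2 = torch.randn(64, 1024).cuda()         # 1024-d conditioning vectors
# DataParallel scatters input1 and input2 into chunks of 32 along dim 0,
# but batchSize is a plain Python int and reaches every replica as 64,
# which I suspect is why input2.view(batchSize, 1024) complains.
fake = netG(64, input1, input2)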
import torch
import torch.nn as nn
import torch.nn.functional as F

class G_tconv(nn.Module):
    def __init__(self, nc, ngf):
        super(G_tconv, self).__init__()
        # encoder: five strided convolutions
        self.conv1 = nn.Conv2d(nc, ngf, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(ngf * 2)
        self.conv3 = nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(ngf * 4)
        self.conv4 = nn.Conv2d(ngf * 4, ngf * 8, 4, 2, 1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(ngf * 8)
        self.conv5 = nn.Conv2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(ngf * 8)
        # decoder: transposed convolutions (convt1 is currently unused in forward)
        self.convt1 = nn.ConvTranspose2d(1024, 128, 4, 1, 0, bias=False)
        self.convt2 = nn.ConvTranspose2d(ngf * 8 + 128, ngf * 8, 4, 2, 1, bias=False)
        self.batchnorm5 = nn.BatchNorm2d(ngf * 8)
        self.convt3 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False)
        self.batchnorm6 = nn.BatchNorm2d(ngf * 4)
        self.convt4 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False)
        self.batchnorm7 = nn.BatchNorm2d(ngf * 2)
        self.convt5 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False)
        self.batchnorm8 = nn.BatchNorm2d(ngf)
        self.convt6 = nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False)
        # 1x1 / 3x3 refinement blocks used between decoder stages
        self.conv_e1 = nn.Conv2d(ngf * 8, ngf * 2, 1, 1, 0, bias=False)
        self.bn_e1 = nn.BatchNorm2d(ngf * 2)
        self.conv_e2 = nn.Conv2d(ngf * 2, ngf * 8, 3, 1, 1, bias=False)
        self.bn_e2 = nn.BatchNorm2d(ngf * 8)
        self.conv_e3 = nn.Conv2d(ngf * 4, ngf, 1, 1, 0, bias=False)
        self.bn_e3 = nn.BatchNorm2d(ngf)
        self.conv_e4 = nn.Conv2d(ngf, ngf * 4, 3, 1, 1, bias=False)
        self.bn_e4 = nn.BatchNorm2d(ngf * 4)
        # projects the 1024-d conditioning vector down to 128 dimensions
        self.linear = nn.Linear(1024, 128)

    def forward(self, batchSize, input1, input2):
        # encoder
        e2 = F.relu(self.conv1(input1))
        e3 = F.relu(self.batchnorm1(self.conv2(e2)))
        e4 = F.relu(self.batchnorm2(self.conv3(e3)))
        e5 = F.relu(self.batchnorm3(self.conv4(e4)))
        e6 = F.relu(self.batchnorm4(self.conv5(e5)))
        # conditioning branch: project input2 and tile it to 4x4
        c1 = self.linear(input2.view(batchSize, 1024))
        c2 = c1.view(batchSize, 128, 1, 1)
        c3 = c2.expand(batchSize, 128, 4, 4)
        # concatenate encoder features with the conditioning maps
        e8 = torch.cat((e6, c3), 1)
        # decoder
        d1_ = F.relu(self.batchnorm5(self.convt2(e8)))
        d1 = F.relu(self.bn_e2(self.conv_e2(F.relu(self.bn_e1(self.conv_e1(d1_))))))
        d2_ = F.relu(self.batchnorm6(self.convt3(d1)))
        d2 = F.relu(self.bn_e4(self.conv_e4(F.relu(self.bn_e3(self.conv_e3(d2_))))))
        d3_ = F.relu(self.batchnorm7(self.convt4(d2)))
        d4_ = F.relu(self.batchnorm8(self.convt5(d3_)))
        d5_ = self.convt6(F.relu(d4_))
        o1 = F.tanh(d5_)
        return o1
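Is the right fix simply to drop batchSize from the forward() signature and derive the per-replica batch size from the scattered tensors themselves? An untested sketch of what I mean:

def forward(self, input1, input2):
    # each replica sees its own chunk, e.g. 32 when the full batch is 64
    batchSize = input1.size(0)
    e2 = F.relu(self.conv1(input1))
    e3 = F.relu(self.batchnorm1(self.conv2(e2)))
    e4 = F.relu(self.batchnorm2(self.conv3(e3)))
    e5 = F.relu(self.batchnorm3(self.conv4(e4)))
    e6 = F.relu(self.batchnorm4(self.conv5(e5)))
    c1 = self.linear(input2.view(batchSize, 1024))
    c2 = c1.view(batchSize, 128, 1, 1)
    c3 = c2.expand(batchSize, 128, 4, 4)
    e8 = torch.cat((e6, c3), 1)
    d1_ = F.relu(self.batchnorm5(self.convt2(e8)))
    d1 = F.relu(self.bn_e2(self.conv_e2(F.relu(self.bn_e1(self.conv_e1(d1_))))))
    d2_ = F.relu(self.batchnorm6(self.convt3(d1)))
    d2 = F.relu(self.bn_e4(self.conv_e4(F.relu(self.bn_e3(self.conv_e3(d2_))))))
    d3_ = F.relu(self.batchnorm7(self.convt4(d2)))
    d4_ = F.relu(self.batchnorm8(self.convt5(d3_)))
    d5_ = self.convt6(F.relu(d4_))
    return F.tanh(d5_)

The call would then be fake = netG(input1, input2), so DataParallel only scatters tensors and every size used inside forward() comes from the chunk each GPU actually receives.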