GPU parallelism with nn.parallel.data_parallel

I’m trying to use multi-GPU processing with code like the following, taken from the PyTorch DCGAN tutorial:

class _netG(nn.Module):
    """DCGAN generator built as a single ``nn.Sequential`` stack.

    Reads the globals ``nz`` (channels of the input ``Z``), ``ngf`` (base
    feature width) and ``nc`` (output channels). These must be defined
    before the class is instantiated.

    Args:
        ngpu: Number of GPUs to split each batch across in ``forward``.
    """

    def __init__(self, ngpu):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run the generator, splitting the batch across GPUs when possible.

        Args:
            input: Latent batch. Presumably (batch, nz, 1, 1), given the
                4 x 4 state after the first layer; confirm against callers.

        Returns:
            The output of ``self.main``; the comments above give its
            per-sample size as (nc) x 64 x 64.
        """
        # ``is_cuda`` accepts CUDA tensors of any dtype (e.g. half precision),
        # not only float32 ones, so every GPU batch can be parallelized.
        if input.is_cuda and self.ngpu > 1:
            # ``self.main`` is an nn.Module, so data_parallel can replicate it.
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output

However, my own code needs to reshape a tensor between layers, so I can't use nn.Sequential anymore. Instead, I have written the following code:

class Gen(nn.Module):
    """Generator from the question: linear projection, reshape, residual block.

    This is the failing example that produces the AttributeError shown in
    the post; the code is left unchanged so the traceback still matches it.

    NOTE(review): as written, ``self.dim_z``, ``self.ch``,
    ``self.bottom_width`` and ``self.ngpu`` are read but never assigned in
    ``__init__``; presumably they are set elsewhere in the real code.
    TODO confirm.
    """

    def __init__(self):
        super(Gen, self).__init__()

        # layers
        # NOTE(review): self.dim_z must already exist when this line runs.
        self.lin = nn.Linear(self.dim_z, 64 * 64)
        # ResBlock_G is defined elsewhere; presumably (in_channels,
        # out_channels) — confirm against its definition.
        self.block1 = ResBlock_G(64 * 8 , 64 * 8)

    def main(self, z, y):
        """Project ``z``, reshape it to a feature map and apply ``block1``.

        This is a plain bound method, not an ``nn.Module``. That is why it
        has no ``parameters()`` and cannot be handed to ``data_parallel``.
        """
        out = self.lin(z)
        # Keep the batch dimension and reshape the rest into a square map.
        # NOTE(review): assuming z is (batch, dim_z), this requires
        # self.ch * 16 * self.bottom_width ** 2 == 64 * 64 (lin's output size).
        out = out.view(out.size(0), self.ch * 16,
                       self.bottom_width, self.bottom_width)
        # y is passed straight through to block1 (conditioning input?) — confirm.
        out = self.block1(out, y)
        return out

    def forward(self, z, y):
        """Run ``main``, attempting multi-GPU execution for CUDA float inputs."""
        if isinstance(z.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            # BUG (the subject of this thread): data_parallel replicates its
            # first argument via ``network.parameters()``, and the bound method
            # ``self.main`` has none, which raises the AttributeError.
            output = nn.parallel.data_parallel(
                self.main, (z, y), range(self.ngpu))
        else:
            output = self.main(z, y)
        return output

The problem is that my self.main is no longer an nn.Module, so in a multi-GPU setting I get the following error:

Traceback (most recent call last):
  File "/workspace/Spectral_Dynamics/main.py", line 72, in <module>
    fake = gen(z, y)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/Spectral_Dynamics/blocks.py", line 309, in forward
    self.main, (z, y), range(self.ngpu))
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 112, in data_parallel
    replicas = replicate(module, used_device_ids)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 10, in replicate
    params = list(network.parameters())
AttributeError: 'function' object has no attribute 'parameters'

Any idea how I can use GPU parallelism in this setting, or how I should change my model definition?

Thanks

2 Likes

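Hey there, try replacing your forward with the following. data_parallel needs an nn.Module to replicate, so pass the module itself (self). Then add an extra argument that tells each replica to run self.main directly instead of parallelizing again: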

    def forward(self, z, y, run_on_ngpu=None):
        """Run ``self.main``, splitting the batch across GPUs when possible.

        ``nn.parallel.data_parallel`` needs an ``nn.Module`` to replicate, so
        the module itself is passed in. Each replica is then called with
        ``run_on_ngpu=1`` and runs ``self.main`` directly instead of
        parallelizing again.

        Args:
            z: First input to ``self.main``; multi-GPU execution is used only
                when it is a CUDA tensor.
            y: Second input, passed through to ``self.main``.
            run_on_ngpu: Number of GPUs to split the batch across; defaults
                to ``self.ngpu``.

        Returns:
            The output of ``self.main(z, y)``, gathered from all replicas in
            the multi-GPU case.
        """
        if run_on_ngpu is None:
            run_on_ngpu = self.ngpu
        if z.is_cuda and run_on_ngpu > 1:
            # The trailing 1 is scattered unchanged to every replica, where it
            # arrives as run_on_ngpu and stops further recursion.
            return nn.parallel.data_parallel(
                self, (z, y, 1), range(run_on_ngpu))
        # CPU input, a single device, or a replica spawned above: run directly
        # rather than raising, matching the original non-parallel fallback.
        return self.main(z, y)

Happy PyTorching!

1 Like

Awesome! Thank you so much!

1 Like