Split single model in multiple GPUs

This is a bit tricky, but it is possible.
I've created a small code example, which uses model sharding and DataParallel.
It uses 4 GPUs, where each submodule is split onto 2 GPUs as a DataParallel module:

import torch
import torch.nn as nn


class SubModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SubModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        # print the device of the chunk this replica received
        print('SubModule, device: {}, shape: {}\n'.format(x.device, x.shape))
        x = self.conv1(x)
        return x


class MyModel(nn.Module):
    def __init__(self, split_gpus, parallel):
        super(MyModel, self).__init__()
        self.module1 = SubModule(3, 6)
        self.module2 = SubModule(6, 1)
        
        self.split_gpus = split_gpus
        self.parallel = parallel
        if self.split_gpus and self.parallel:
            # each submodule is wrapped in DataParallel on its own pair of GPUs;
            # the .to() call places its parameters on the first device of that pair
            self.module1 = nn.DataParallel(self.module1, device_ids=[0, 1]).to('cuda:0')
            self.module2 = nn.DataParallel(self.module2, device_ids=[2, 3]).to('cuda:2')
        
    def forward(self, x):
        print('Input: device {}, shape {}\n'.format(x.device, x.shape))
        x = self.module1(x)
        print('After module1: device {}, shape {}\n'.format(x.device, x.shape))
        x = self.module2(x)
        print('After module2: device {}, shape {}\n'.format(x.device, x.shape))
        return x


model = MyModel(split_gpus=True, parallel=True)
x = torch.randn(16, 3, 24, 24).to('cuda:0')
output = model(x)

The script will output:

Input: device cuda:0, shape torch.Size([16, 3, 24, 24])
SubModule, device: cuda:0, shape: torch.Size([8, 3, 24, 24])
SubModule, device: cuda:1, shape: torch.Size([8, 3, 24, 24])
After module1: device cuda:0, shape torch.Size([16, 6, 24, 24])
SubModule, device: cuda:2, shape: torch.Size([8, 6, 24, 24])
SubModule, device: cuda:3, shape: torch.Size([8, 6, 24, 24])
After module2: device cuda:2, shape torch.Size([16, 1, 24, 24])
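
Note that the final output stays on cuda:2 (the first device of the second DataParallel group), so a target tensor has to be pushed to the same device before computing a loss. Here is a minimal, untested sketch continuing the example above; the criterion and optimizer are just placeholder choices:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# the target has to live on the same device as the model output (cuda:2)
target = torch.randn(16, 1, 24, 24).to('cuda:2')

optimizer.zero_grad()
output = model(x)
loss = criterion(output, target)
loss.backward()  # gradients are routed back through both device groups automatically
optimizer.step()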

EDIT: As you can see, I only implemented this one use case, so the conditions on self.split_gpus and self.parallel are a bit useless here. However, this should give you a starting point for your code.
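
For the split_gpus=True, parallel=False case, a plain model sharding version (without DataParallel) could look roughly like the sketch below. The class name and the device assignment are just an illustration, not part of the snippet above:

class MyShardedModel(nn.Module):
    def __init__(self):
        super(MyShardedModel, self).__init__()
        # each submodule lives on its own GPU
        self.module1 = SubModule(3, 6).to('cuda:0')
        self.module2 = SubModule(6, 1).to('cuda:1')

    def forward(self, x):
        x = self.module1(x)
        # move the activation to the device of the next submodule
        x = x.to('cuda:1')
        x = self.module2(x)
        return x


model = MyShardedModel()
x = torch.randn(16, 3, 24, 24).to('cuda:0')
output = model(x)  # output will be on cuda:1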
