I’d like to implement distributed model parallelism at the module level, e.g. for nn.Linear. The following snippet splits the module and distributes the pieces across multiple GPUs, but the computation does not actually run in parallel.
import math

import torch
import torch.nn as nn


class DistLinear(nn.Module):
    # Column-parallel nn.Linear: the output features are sharded across GPUs.
    def __init__(self, input_size, output_size, bias=True, io_gpu=0, num_gpus=2):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.bias = bias
        self.io_gpu = io_gpu      # GPU that receives the input and holds the gathered output
        self.num_gpus = num_gpus
        self.setup_modules()

    def setup_modules(self):
        # Split the output features into num_gpus buckets; the last bucket takes the remainder.
        self.bucket_size = math.ceil(self.output_size / self.num_gpus)
        last_bucket_size = self.output_size - self.bucket_size * (self.num_gpus - 1)
        self.modular = nn.ModuleList([self.get_module(self.bucket_size)
                                      for _ in range(self.num_gpus - 1)])
        self.modular.append(self.get_module(last_bucket_size))
        # Place the i-th shard on GPU i.
        for i in range(self.num_gpus):
            self.modular[i].cuda(i)

    def get_module(self, output_size):
        return nn.Linear(self.input_size, output_size, bias=self.bias)

    def forward(self, x):  # x: batch x hidden
        xs = []
        for i in range(self.num_gpus):
            _x = x.cuda(i)                 # copy the input to GPU i
            _x = self.modular[i](_x)       # run this GPU's shard of the layer
            _x = _x.cuda(self.io_gpu)      # copy the partial output back to the io GPU
            xs.append(_x)
        # Concatenate the per-GPU output slices along the feature dimension.
        return torch.cat(xs, dim=-1)
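For reference, here is a minimal usage sketch (the sizes and batch below are arbitrary placeholders, not the ones from my actual model), assuming at least two GPUs are visible:

# Arbitrary placeholder sizes for illustration only.
layer = DistLinear(input_size=1024, output_size=4096, num_gpus=2)
x = torch.randn(32, 1024, device="cuda:0")   # batch x hidden, on the io_gpu
y = layer(x)
print(y.shape)    # torch.Size([32, 4096]), output slices gathered on io_gpu
print(y.device)   # cuda:0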