I’d like to implement distributed model parallelism at the module level, e.g. for a module such as `nn.Linear`. The following snippet splits the module into pieces and distributes them across multiple GPUs, but the pieces are not actually computed in parallel.
class DistLinear(nn.Module):
    """Linear layer whose output features are sharded across several GPUs.

    The ``output_size`` output features are split into ``num_gpus`` contiguous
    shards. Shard ``i`` is an ``nn.Linear(input_size, shard_size)`` placed on
    ``cuda:i``. In ``forward`` the input is copied to every shard's device,
    each shard computes its slice of the output, and the slices are copied to
    ``cuda:io_gpu`` and concatenated along the last dimension.

    Args:
        input_size: Number of input features (last dimension of ``x``).
        output_size: Total number of output features across all shards.
        bias: Forwarded as ``bias=`` to every shard's ``nn.Linear``.
        io_gpu: Device index on which ``forward`` returns its result.
        num_gpus: Number of shards; shards live on devices ``0..num_gpus-1``.
    """

    def __init__(self, input_size, output_size, bias=True, io_gpu=0, num_gpus=2):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.bias = bias  # flag passed to each shard's nn.Linear, not a Parameter
        self.io_gpu = io_gpu
        self.num_gpus = num_gpus
        self.setup_modules()

    @staticmethod
    def _shard_sizes(output_size, num_gpus):
        """Return the output width of each of the ``num_gpus`` shards.

        Shards get ``ceil(output_size / num_gpus)`` features each, in order,
        until the features run out; the remainder goes to the next shard and
        any shards after that get width 0. The widths always sum to
        ``output_size``.

        Raises:
            ValueError: If ``num_gpus < 1`` or ``output_size < 0``.
        """
        if num_gpus < 1:
            raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")
        if output_size < 0:
            raise ValueError(f"output_size must be >= 0, got {output_size}")
        bucket = -(-output_size // num_gpus)  # integer ceil division
        # Clamping at 0 handles splits where ceil-sized buckets overshoot
        # output_size (e.g. 5 features on 4 GPUs -> 2, 2, 1, 0 instead of a
        # negative last width). For every split that was already valid the
        # widths are unchanged, so parameter shapes stay checkpoint-compatible.
        return [
            max(0, min(bucket, output_size - bucket * i)) for i in range(num_gpus)
        ]

    def setup_modules(self):
        """Create one ``nn.Linear`` shard per GPU and move shard ``i`` to ``cuda:i``."""
        sizes = self._shard_sizes(self.output_size, self.num_gpus)
        # Width of the leading shard, i.e. ceil(output_size / num_gpus).
        self.bucket_size = sizes[0]
        self.modular = nn.ModuleList(self.get_module(size) for size in sizes)
        for device, module in enumerate(self.modular):
            module.cuda(device)

    def get_module(self, output_size):
        """Build one shard: an ``nn.Linear`` from ``input_size`` to ``output_size``."""
        return nn.Linear(self.input_size, output_size, bias=self.bias)

    def forward(self, x):
        """Apply every shard to ``x`` and concatenate the results on ``io_gpu``.

        Args:
            x: Input tensor whose last dimension is ``input_size`` (the original
                note describes it as ``batch x hidden``).

        Returns:
            Tensor on ``cuda:io_gpu`` whose last dimension is ``output_size``.
        """
        # The three passes are deliberately separate. PyTorch enqueues a
        # cross-device copy on the source device's current stream and makes
        # both devices' streams wait on each other. With a single
        # "copy in -> compute -> copy out" loop, the copy of x to shard i+1
        # queues behind shard i's gather (or, when io_gpu is a shard device,
        # behind shard i's matmul). The GPUs then effectively run one after
        # another. Issuing all scatters first, then all matmuls, then all
        # gathers lets the asynchronously launched matmuls overlap.
        inputs = [x.cuda(device) for device in range(self.num_gpus)]
        outputs = [module(inp) for module, inp in zip(self.modular, inputs)]
        gathered = [out.cuda(self.io_gpu) for out in outputs]
        return torch.cat(gathered, dim=-1)