Here’s the thing. Assuming that I have 4 gpus, ‘gpu0’, ‘gpu1’, ‘gpu2’ and ‘gpu3’. I need to train one network(net1) but the input to the net1 is provided by another network(net2). I want to use all my gpus to do the training.
So first , I use two processes in my program, process1 for performing the forward step of net2 to get the inputs of net1, then using torch.distributed.broadcast to send the input to process2. Process2 receives the input sent by process1 then perform the forward and backward step for training net1.
In process1, I use torch.nn.parallel.DistributedDataParallel to put net1 on ‘gpu0’ and ‘gpu1’ .
In process2, I use torch.nn.parallel.DistributedDataParallel to put net2 on ‘gpu2’ and ‘gpu3’.
Here is how I implement a toy model:
from torch.multiprocessing import Process
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
class Random_Dataset(Dataset):
def __init__(self, num=1200, dim=300):
self.num = num
self.dim = dim
self.data = torch.rand(num, dim)
def __len__(self):
return self.num
def __getitem__(self, idx):
return self.data[idx]
class Model(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super(Model, self).__init__()
self.rnn = nn.LSTM(
input_size=input_size, hidden_size=hidden_size,
num_layers=num_layers, batch_first=True
)
def forward(self, x, lengths):
total_length = x.size(1)
self.rnn.flatten_parameters()
x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
outputs, (h_n, c_n) = self.rnn(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(
outputs, batch_first=True, total_length=total_length
)
return outputs
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(Generator, self).__init__()
padding = int((kernel_size - 1) / 2)
self.conv = nn.Conv1d(
in_channels, out_channels, kernel_size,
padding=padding
)
def forward(self, x):
x = torch.unsqueeze(x, 1)
x = self.conv(x)
x = F.sigmoid(x)
return x
def init_process(rank, world_size, backend, func, params):
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
group = dist.new_group([0, 1])
func(*params, group)
def generate_data(dataloader, num_epochs, rank, device_ids, network_params,
use_cuda, group):
generator = Generator(*network_params)
if use_cuda:
generator = generator.cuda()
print("Network initialized!")
generator = nn.parallel.distributed.DistributedDataParallel(
generator, device_ids=device_ids
)
for epoch in range(num_epochs):
for i_batch, batch in enumerate(dataloader):
batch = batch.cuda()
rnn_input = generator(batch)
dist.broadcast(rnn_input, rank, group)
print("epoch:{}, batch_num:{}, broadcast finished!".format(
epoch, i_batch
)
)
def run_rnn(num_batchs, num_epochs, src_rank, device_ids, network_params,
input_size, use_cuda, group):
rnn = Model(*network_params)
if use_cuda:
rnn = rnn.cuda()
print("Network initialized!")
rnn = nn.parallel.distributed.DistributedDataParallel(
rnn, device_ids=device_ids
)
optimizer = optim.Adam(rnn.parameters(), lr=1e-4)
for epoch in range(num_epochs):
for i_batch in range(num_batchs):
optimizer.zero_grad()
rnn_input = torch.Tensor(input_size).cuda()
dist.broadcast(rnn_input, src_rank, group)
print("epoch:{}, batch_num:{}, receive finished!".format(
epoch, i_batch
)
)
batch_size = rnn_inputs.size(0)
lengths = np.random.randint(low=3, high=30, size=(batch_size))
lengths = -np.sort(-lengths)
lengths = torch.from_numpy(lengths).long().cuda()
out = rnn(rnn_input, lengths)
out = torch.sum(out)
out.backward()
optimizer.step()
print("out:{}".format(out.item()))
rnn = rnn.cpu()
torch.save('rnn.net', rnn.state_dict)
def main(use_cuda=False):
world_size = 2
processes = []
dataset = Random_Dataset()
dataloader = DataLoader(
dataset, batch_size=12, shuffle=True, num_workers=1
)
num_epochs = 2
num_batchs = 100
generator_device_ids = [0, 1]
rnn_device_ids = [2, 3]
generator_params = (1, 50, 5)
rnn_params = (300, 300, 3)
p1 = Process(
target=init_process,
args=(0, world_size, 'nccl', generate_data,
(dataloader, num_epochs, 0, generator_device_ids,
generator_params, use_cuda)
)
)
p1.start()
processes.append(p1)
p2 = Process(
target=init_process,
args=(1, world_size, 'gloo', run_rnn,
(num_batchs, num_epochs, 0,
rnn_device_ids, rnn_params, torch.Size((12, 50, 300)),
use_cuda)
)
)
p2.start()
processes.append(p2)
for p in processes:
p.join()
if __name__ == '__main__':
main(True)
The code hangs when call “torch.nn.parallel.DistributedDataParallel”.
Clearly two networks in above code is small enough to fit in just one gpu, but if I want the code to run like what I described at first, how should I modify my code?