Distributed model parallelism

I want to implement distributed model parallelism with PyTorch, but I cannot find any example of it; I have only found examples of distributed data parallelism. My question is: if I split a computation graph across two nodes, how can I still use autograd to compute the gradients and update the weights?

Any help will be appreciated.

Hi! There’s no way to make autograd aware of the fact that your model is spread across multiple nodes (we have a distributed mode like this in the works, but it still needs some work to be completed, and it will take some time to get there).

We don’t have an official example of model parallelism, but it would be something along the lines of this (the code is incomplete and might contain minor errors, but you get the idea):

import torch
import torch.distributed as dist
from torch.autograd import Variable

dist.init_process_group(...)
rank = dist.get_rank()
if rank == 0:
  model = ... # create the first part of the model
else:
  model = ... # create the second part of the model

if rank == 0:
  for data, labels in data_loader:
    optimizer.zero_grad()
    data = Variable(data)
    output = model(data)
    dist.send(output.data, dst=1)
    # Now we're done with our part of the forward pass.
    # Wait for the gradient w.r.t. our output, computed on node 1.
    grad_output = torch.FloatTensor(output.size())
    dist.recv(grad_output, src=1)
    output.backward(grad_output)
    optimizer.step() # Do optimization on the first part
else:
  for data, labels in data_loader:
    optimizer.zero_grad()
    labels = Variable(labels)
    # Wait for the output of the first part of the model
    first_part_output = torch.FloatTensor(...) # buffer of the right size
    dist.recv(first_part_output, src=0)

    # Compute the final part and the gradient w.r.t. its input
    input = Variable(first_part_output, requires_grad=True)
    output = model(input)
    loss = loss_fn(output, labels)
    loss.backward()
    dist.send(input.grad.data, dst=0) # Send gradients back to node 0
    optimizer.step() # Do optimization on the second part
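
The init_process_group(...) call is left elided above; just as an illustration, initialization could look roughly like this (a sketch assuming the TCP backend and a placeholder master address, with the rank passed on the command line — adjust to your setup and PyTorch version):

import argparse
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--rank', type=int, required=True)   # 0 runs the first part, 1 the second
parser.add_argument('--world-size', type=int, default=2)
args = parser.parse_args()

# The tcp init method needs the address of the rank 0 machine (placeholder below).
dist.init_process_group(backend='tcp',
                        init_method='tcp://192.168.1.10:23456',
                        rank=args.rank,
                        world_size=args.world_size)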

Thank you, I see how your code works. I revised the MNIST example into this mode and found that the model reaches the same accuracy as the original MNIST example, so I think it works. If you find anything wrong with it, please tell me.

def train(epoch):
    if dist.get_rank() == 0:
        model.train()
        model.cuda()
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
        for batch_idx, (data, target) in enumerate(train_loader):
            input_from_part2 = torch.FloatTensor(data.size()[0], 320)
            optimizer.zero_grad()
            data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output = model(data)
            dist.send(output.data.cpu(), dst=1)

            # Wait for the gradient of our output, computed on node 1
            dist.recv(tensor=input_from_part2, src=1)
            output.backward(input_from_part2.cuda())
            optimizer.step()
    else:
        model.train()
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
        for batch_idx, (data, target) in enumerate(train_loader):
            output_from_part1 = torch.FloatTensor(data.size()[0], 320)
            optimizer.zero_grad()
            target = Variable(target)

            # Wait for the output of the first part of the model
            dist.recv(tensor=output_from_part1, src=0)
            input = Variable(output_from_part1, requires_grad=True)
            output = model(input)
            loss = F.nll_loss(output, target)
            loss.backward()
            # Send the gradient w.r.t. the received activations back to node 0
            dist.send(input.grad.data, dst=0)
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data[0]))
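
For reference, the split itself can look something like this (a sketch based on the standard MNIST example network; the 320 used for the buffers above is the size of the flattened conv features):

import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class Part1(nn.Module):
    """Convolutional part of the MNIST net, run on rank 0."""
    def __init__(self):
        super(Part1, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        return x.view(-1, 320)  # 20 channels * 4 * 4 = 320 features

class Part2(nn.Module):
    """Fully connected part of the MNIST net, run on rank 1."""
    def __init__(self):
        super(Part2, self).__init__()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x))

model = Part1() if dist.get_rank() == 0 else Part2()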

The code looks good! I’ve cleaned it up a bit:

def train(epoch):
    model.train()
    model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    if dist.get_rank() == 0:
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            data = Variable(data.cuda())
            output = model(data)
            dist.send(output.data.cpu(), dst=1)

            # Wait for the gradient of our output, computed on node 1
            input_from_part2 = torch.FloatTensor(data.size(0), 320)
            dist.recv(tensor=input_from_part2, src=1)
            output.backward(input_from_part2.cuda())
            optimizer.step()
    else:
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            target = Variable(target.cuda())

            # Wait for the output of the first part of the model
            output_from_part1 = torch.FloatTensor(data.size(0), 320)
            dist.recv(tensor=output_from_part1, src=0)
            input = Variable(output_from_part1.cuda(), requires_grad=True)
            output = model(input)
            loss = F.nll_loss(output, target)
            loss.backward()
            # Send the gradient w.r.t. the received activations back to node 0
            dist.send(input.grad.data.cpu(), dst=0)
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data[0]))
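
One thing worth double-checking (not shown above): both ranks iterate over their own train_loader, so the data batches seen on rank 0 and the target batches seen on rank 1 have to come out in the same order, otherwise the loss is computed against the wrong labels. A minimal way to keep shuffled loaders in sync, assuming both scripts share the same args.seed, is to re-seed at the start of every epoch:

def train(epoch):
    # Re-seed so that both ranks draw the same shuffling permutation for this
    # epoch; otherwise rank 1's targets would not correspond to rank 0's data.
    torch.manual_seed(args.seed + epoch)
    ...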