RPC parameter server - accumulated gradients with multiple calls to dist_autograd.backward?

When training locally, it is quite straightforward to simulate a large batch size by calling backward in a loop and finally running `opt.step()`. An example can be found here: deep learning - How to compensate if I cant do a large batch size in neural network - Stack Overflow

How can we replicate this behaviour with distributed autograd and the RPC parameter server paradigm?

Below is a minimal example based on the RPC parameter server example. However, this code gets stuck at the second call to dist_autograd.backward(cid, [loss]) in run_training_loop. Why is this? Is there a workaround?

import os
from threading import Lock
import time
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torchvision import datasets, transforms

def run_training_loop(rank, world_size, train_loader):
    """Accumulate gradients over all batches, then apply one optimizer step.

    Args:
        rank: RPC rank of this trainer; also selects its model on the server.
        world_size: Total number of RPC processes (server plus trainers).
        train_loader: Iterable yielding ``(data, target)`` batches.
    """
    net = TrainerNet(rank=rank)

    # Poll the parameter server until it holds one model per trainer
    # (every process except the server itself).
    expected_models = world_size - 1
    while net.param_server_rref.rpc_sync().get_num_models() != expected_models:
        time.sleep(0.5)

    opt = DistributedOptimizer(
        optim.SGD,
        net.get_global_param_rrefs(),
        lr=0.03,
    )

    # A single autograd context spans every batch, so each backward call
    # below targets the same context id.
    with dist_autograd.context() as cid:
        for i, (data, target) in enumerate(train_loader):
            model_output = net(data)
            target = target.to(model_output.device)
            loss = F.nll_loss(model_output, target)
            print(f"Rank {rank} step {i} training batch {i} loss {loss.item()}")
            dist_autograd.backward(cid, [loss])  # stuck here at second iteration!

        # One optimizer step after the whole loader has been consumed.
        opt.step(cid)

class ParameterServer(nn.Module):
    """Hosts a reference model plus one separate model per trainer rank.

    Args:
        model_fn: Zero-argument callable returning a new ``nn.Module``.
        device: Device on which all hosted models are placed.
    """

    def __init__(self, model_fn, device):
        super().__init__()
        self.input_device = device
        self.model_fn = model_fn
        # Reference weights that every per-rank model is initialised from.
        self.global_model = model_fn().to(self.input_device)
        # rank -> model; a plain dict, guarded by ``models_init_lock``.
        self.models_init_lock = Lock()
        self.models = {}

    def forward(self, rank, inp):
        """Run ``inp`` through the model owned by ``rank``; result is on CPU."""
        on_device = inp.to(self.input_device)
        replica = self.models[rank]
        result = replica(on_device)
        return result.to("cpu")

    def get_param_rrefs(self, rank):
        """Return one ``rpc.RRef`` per parameter of the model owned by ``rank``."""
        rrefs = []
        for param in self.models[rank].parameters():
            rrefs.append(rpc.RRef(param))
        return rrefs

    def create_model_for_rank(self, rank):
        """Create the model for ``rank`` if absent, seeded from ``global_model``."""
        with self.models_init_lock:
            if rank in self.models:
                return
            replica = self.model_fn().to(self.input_device)
            self.models[rank] = replica
            replica.load_state_dict(self.global_model.state_dict())

    def get_num_models(self):
        """Return how many per-rank models exist so far."""
        with self.models_init_lock:
            count = len(self.models)
        return count

# Process-wide ParameterServer instance; stays None until first requested.
param_server = None
# Serialises creation of (and access to) the ``param_server`` global.
global_lock = Lock()


def model_fn():
    """Build a new, randomly initialised CNN classifier.

    The ``Linear(9216, 128)`` layer fixes the flattened feature size at
    64 * 12 * 12, which corresponds to single-channel 28x28 inputs.

    Returns:
        An ``nn.Sequential`` producing log-probabilities over 10 classes.
    """
    return nn.Sequential(
        nn.Conv2d(1, 32, 3, 1),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3, 1),
        nn.MaxPool2d(2),
        nn.Dropout2d(0.25),
        nn.Flatten(),
        nn.Linear(9216, 128),
        nn.ReLU(),
        # The activations are 2-D (batch, features) here, so plain Dropout
        # is the right layer; Dropout2d on 2-D input is deprecated.
        nn.Dropout(0.5),
        nn.Linear(128, 10),
        # Explicit class dimension; the implicit-dim form is deprecated.
        nn.LogSoftmax(dim=1),
    )

def get_parameter_server(rank):
    """Return the process-wide ParameterServer, creating it on first use.

    Also makes sure the server holds a model for ``rank`` before returning.

    Args:
        rank: Rank of the trainer requesting the server.

    Returns:
        The shared ``ParameterServer`` instance.
    """
    global param_server
    with global_lock:
        if param_server is None:
            # Use the GPU when one is present; otherwise fall back to CPU so
            # the script still runs on CPU-only hosts.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            param_server = ParameterServer(model_fn, device)
        param_server.create_model_for_rank(rank)
        return param_server

class TrainerNet(nn.Module):
    """Trainer-side proxy that forwards every call to the parameter server.

    Args:
        rank: Rank of this trainer; passed along with each server call.
    """

    def __init__(self, rank):
        super().__init__()
        self.rank = rank
        # Remote reference to the server object living on the
        # "parameter_server" worker.
        self.param_server_rref = rpc.remote(
            "parameter_server",
            get_parameter_server,
            args=(self.rank,),
        )

    def _server(self):
        """Return a synchronous RPC proxy for the remote server."""
        return self.param_server_rref.rpc_sync()

    def get_global_param_rrefs(self):
        """Fetch the parameter RRefs of this trainer's model on the server."""
        return self._server().get_param_rrefs(self.rank)

    def forward(self, x):
        """Run ``x`` through this trainer's model on the server."""
        return self._server().forward(self.rank, x)

def start(rank, world_size):
    """Per-process entry point: rank 0 is the server, other ranks train.

    Args:
        rank: Index of this process.
        world_size: Total number of processes taking part in RPC.
    """
    if rank != 0:
        # The data pipeline is built before joining the RPC group.
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        mnist_train = datasets.MNIST(
            '../data', train=True, download=True, transform=preprocess)
        train_loader = torch.utils.data.DataLoader(
            mnist_train, batch_size=16, shuffle=True)

        rpc.init_rpc(name=f"trainer_{rank}", rank=rank, world_size=world_size)
        run_training_loop(rank, world_size, train_loader)
    else:
        # The server only joins the group; it does no work of its own here.
        rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size)

    # Reached by both roles.
    rpc.shutdown()

if __name__ == '__main__':
    # Two processes in total; ``start`` makes rank 0 the server and every
    # other rank a trainer.
    world_size = 2

    # Rendezvous settings are passed through environment variables.
    master_addr, master_port = "localhost", "29500"
    os.environ.update({
        'MASTER_ADDR': master_addr,
        'MASTER_PORT': master_port,
    })

    mp.set_start_method("spawn")
    mp.spawn(start, nprocs=world_size, args=(world_size,), join=True)

Hi, looking at your code, the comment suggests it gets stuck at “dist_autograd.backward(cid, [loss])” — is that right?

Do you mind sharing your log here?
Also, I tried to reproduce this locally myself, but I got errors. Could you kindly share your environment (OS, PyTorch version, etc.) so that I can reproduce it?

Thanks!

That’s right, I verified this using a debugger, and the second call to backward never returns. Example output:

Rank 1 step 0 training batch 0 loss 2.2719762325286865
Rank 1 step 1 training batch 1 loss 2.242600679397583

PyTorch version: 1.12.1+cu116
Python version: 3.8.10
System: Ubuntu Linux 20.04 - Kernel 5.15.0-46

If you get an RPC error saying the address is already in use, changing the value of master_port may help.

I got a different error… OK, looking at the wiki again, it seems the recommended approach is to pass the rank in so that you can run two or more processes on different machines. How do you run your program locally?

I place the script included in my original post in a file, say my_script.py, and run it from the command line via python3 my_script.py.

cc: @Rohan_Varma who has more context on parameter server.