When training locally, it is quite straightforward to simulate a large batch size by calling `backward()` in a loop and only then running `opt.step()`. An example can be found in this Stack Overflow question: "How to compensate if I cant do a large batch size in neural network".
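For reference, this is the local pattern I mean, as a minimal sketch (the tiny linear model, random data, and the `accum_steps` value are just placeholders to make it self-contained):

```python
import torch
import torch.nn.functional as F
from torch import nn, optim

# Placeholder model, optimizer, and data just to make the sketch runnable.
model = nn.Linear(10, 2)
opt = optim.SGD(model.parameters(), lr=0.03)
loader = [(torch.randn(16, 10), torch.randint(0, 2, (16,))) for _ in range(8)]

accum_steps = 4  # effective batch size = 16 * 4
opt.zero_grad()
for i, (data, target) in enumerate(loader):
    loss = F.cross_entropy(model(data), target) / accum_steps  # scale so grads average
    loss.backward()                                            # grads accumulate in .grad
    if (i + 1) % accum_steps == 0:
        opt.step()       # apply the accumulated grads
        opt.zero_grad()  # reset for the next "large batch"
```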
How can we replicate this behaviour with distributed autograd and the RPC parameter server paradigm?

Below is a minimal example based on the RPC parameter server example. However, this code gets stuck on the second training batch, at the call to `dist_autograd.backward(cid, [loss])` in `run_training_loop` (see the comment in the code below). Why is this? Is there a workaround?
```python
import os
import time
from threading import Lock

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torchvision import datasets, transforms

def run_training_loop(rank, world_size, train_loader):
    net = TrainerNet(rank=rank)
    # Wait until the parameter server has created a model for every trainer.
    num_created = net.param_server_rref.rpc_sync().get_num_models()
    while num_created != world_size - 1:
        time.sleep(0.5)
        num_created = net.param_server_rref.rpc_sync().get_num_models()

    param_rrefs = net.get_global_param_rrefs()
    opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.03)
    with dist_autograd.context() as cid:
        # Accumulate grads over multiple batches.
        for i, (data, target) in enumerate(train_loader):
            model_output = net(data)
            target = target.to(model_output.device)
            loss = F.nll_loss(model_output, target)
            print(f"Rank {rank} training batch {i} loss {loss.item()}")
            dist_autograd.backward(cid, [loss])  # stuck here at the second iteration!
        # Finally apply the accumulated grads.
        opt.step(cid)


class ParameterServer(nn.Module):
    def __init__(self, model_fn, device):
        super().__init__()
        self.model_fn = model_fn
        self.input_device = device
        self.global_model = model_fn().to(self.input_device)
        self.models = {}
        self.models_init_lock = Lock()

    def forward(self, rank, inp):
        inp = inp.to(self.input_device)
        return self.models[rank](inp).to("cpu")

    def get_param_rrefs(self, rank):
        return [rpc.RRef(param)
                for param in self.models[rank].parameters()]

    def create_model_for_rank(self, rank):
        with self.models_init_lock:
            if rank not in self.models:
                self.models[rank] = self.model_fn().to(self.input_device)
                self.models[rank].load_state_dict(self.global_model.state_dict())

    def get_num_models(self):
        with self.models_init_lock:
            return len(self.models)


param_server = None
global_lock = Lock()

model_fn = lambda: nn.Sequential(
    nn.Conv2d(1, 32, 3, 1),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, 1),
    nn.MaxPool2d(2),
    nn.Dropout2d(0.25),
    nn.Flatten(),
    nn.Linear(9216, 128),
    nn.ReLU(),
    nn.Dropout2d(0.5),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1)
)


def get_parameter_server(rank):
    global param_server
    with global_lock:
        if not param_server:
            param_server = ParameterServer(model_fn, torch.device("cuda"))
        param_server.create_model_for_rank(rank)
        return param_server


class TrainerNet(nn.Module):
    def __init__(self, rank):
        super().__init__()
        self.rank = rank
        self.param_server_rref = rpc.remote(
            "parameter_server", get_parameter_server, args=(self.rank,))

    def get_global_param_rrefs(self):
        return self.param_server_rref.rpc_sync().get_param_rrefs(self.rank)

    def forward(self, x):
        return self.param_server_rref.rpc_sync().forward(self.rank, x)


def start(rank, world_size):
    if rank == 0:
        rpc.init_rpc(name="parameter_server", rank=rank, world_size=world_size)
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                '../data', train=True, download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])),
            batch_size=16, shuffle=True)
        rpc.init_rpc(name=f"trainer_{rank}", rank=rank, world_size=world_size)
        run_training_loop(rank, world_size, train_loader)
    rpc.shutdown()


if __name__ == '__main__':
    world_size = 2
    master_addr = "localhost"
    master_port = "29500"
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    mp.set_start_method("spawn")
    mp.spawn(start, args=(world_size,), nprocs=world_size, join=True)
```