RPC parameter server implementation: How can I optimize a model on the server without carrying out the forward/backward pass on the server (only sending gradients to the server)?

I attempted to create my own parameter server implementation where the PS only combines gradients, does the optimization step and some synchronization.

My code looks roughly like this:

worker:

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.models import resnet18

def run_training_loop(rank, num_gpus, train_loader, test_loader, param_server_rref):

    # local copy of the model; only the final layer is replaced
    model = resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, NUMBER_OF_CLASSES)
    optimizer = optim.SGD(model.parameters(), lr=0.0002, momentum=0.9)

    for epoch in range(NUMBER_OF_EPOCHS):

        for i, (data, target) in enumerate(train_loader):

            # forward/backward runs locally on the worker
            model_output = model(data)
            loss = F.cross_entropy(model_output, target)
            loss.backward()

            # send the model (with gradients) to the PS and receive the updated model back
            model = param_server_rref.rpc_sync().optimize(rank, model)

        # evaluate after each epoch (get_accuracy is defined elsewhere)
        get_accuracy(test_loader, model)

    print("Training complete!")

server:

def optimize(self, rank, model):
    self.barrier.wait()

    if rank == 1:
        # rank 1 combines the workers' gradients and applies the update
        self.combined_model = self.combine_model()
        optimizer = optim.SGD(self.combined_model.parameters(), lr=0.0002, momentum=0.9)
        optimizer.step()
        optimizer.zero_grad()

    self.barrier.wait()
    return copy.deepcopy(self.combined_model)

Unfortunately the optimization step on the parameter server doesn’t seem to have any effect!

I suspect that autograd is not recording a graph, since the optimizer is created after the forward/backward calls?!
How can I get this working without doing all computation on the server?

The distributed_optimizer seems to be only useful for model parallelism and not data parallelism since it does not respect gradients from different workers. Is this correct?

Thank you in advance.


Hi @MichaelZ, thanks for posting the question. It seems to me that one of your issues is that the optimize method on the server side creates a new optimizer every time it is called. This resets the optimizer state for the model parameters on every update, even though autograd might be recording the gradients correctly. You should create the optimizer once and, inside optimize, only call its step function, I think (see the sketch below). Also, I am not sure what combine_model() does, so I can’t tell whether it carries the gradients over correctly or not.
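A minimal sketch of that suggestion, staying close to the snippet above (the constructor arguments and what combine_model() is supposed to do are assumptions on my part):

import copy
import threading

import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18

class ParameterServer:
    def __init__(self, num_workers, num_classes):
        self.barrier = threading.Barrier(num_workers)
        # same architecture as on the workers
        model = resnet18(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        self.combined_model = model
        # created once, so the optimizer state (e.g. momentum buffers) persists across steps
        self.optimizer = optim.SGD(self.combined_model.parameters(), lr=0.0002, momentum=0.9)

    def combine_model(self):
        # placeholder for your combine_model(): it should write the
        # (summed/averaged) worker gradients into self.combined_model's
        # parameters' .grad fields in place
        pass

    def optimize(self, rank, model):
        self.barrier.wait()
        if rank == 1:
            self.combine_model()
            self.optimizer.step()
            self.optimizer.zero_grad()
        self.barrier.wait()
        return copy.deepcopy(self.combined_model)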

The distributed_optimizer seems to be only useful for model parallelism and not data parallelism since it does not respect gradients from different workers. Is this correct?

The distributed optimizer is mainly used for model parallelism, as you said, but in that case it does respect the gradients from different workers. For DDP, the gradients are all-reduced across workers, so a plain local optimizer is enough.
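For reference, the pattern the distributed optimizer is designed for looks roughly like this (adapted from the torch.distributed.optim documentation; it assumes rpc.init_rpc has already been called and that a peer named "worker1" exists):

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
    # forward pass: parts of the computation live on a remote worker
    rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
    rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
    loss = rref1.to_here() + rref2.to_here()

    # distributed backward pass within this autograd context
    dist_autograd.backward(context_id, [loss.sum()])

    # the optimizer takes RRefs to remote parameters and steps them on the
    # workers that own them
    dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
    )
    dist_optim.step(context_id)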

Hi @MichaelZ, I have the same goal of running forward() on each worker instead of on the server. Have you found a solution to this problem? Also, can you share what is in self.combine_model()? Thank you so much.

edit: After briefly reading through the distributed-rpc topics, I found one where someone mentioned batch RPC processing, and this seems to be what we are looking for.
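For anyone finding this later, the core of that batch RPC processing idea, condensed from the PyTorch batch-updating parameter server tutorial (so the class and method names below come from that tutorial, not from the code earlier in this thread), is: each worker sends its gradients, the server only takes the optimizer step once all workers have reported, and every caller then receives the same updated model through a shared future.

import threading

import torch
import torch.distributed.rpc as rpc
from torch import optim

class BatchUpdateParameterServer:
    def __init__(self, model, batch_update_size):
        self.model = model
        self.batch_update_size = batch_update_size
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.curr_update_size = 0
        self.optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        with self.lock:
            # accumulate this worker's gradients
            for p, g in zip(self.model.parameters(), grads):
                p.grad += g
            fut = self.future_model
            self.curr_update_size += 1
            if self.curr_update_size >= self.batch_update_size:
                # every worker has reported: average, step once, and hand the
                # updated model to all pending callers via the shared future
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad(set_to_none=False)
                fut.set_result(self.model)
                self.future_model = torch.futures.Future()
        return fut

A worker would then call something like model = rpc.rpc_sync(ps_rref.owner(), BatchUpdateParameterServer.update_and_fetch_model, args=(ps_rref, [p.grad for p in model.parameters()])) after its local loss.backward().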

Hi, sorry for the late reply, I am currently very busy.
I decided to use the PyTorch c10d library instead, so I can’t tell you how to do it with the RPC package. I found the p2p and collective functions of c10d quite intuitive to use. The collective functions also make it very easy to combine gradients (by summing them); a sketch of that follows below.
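To illustrate that gradient combining with a collective (a sketch, assuming dist.init_process_group has already been called on every process):

import torch.distributed as dist

def average_gradients(model):
    # sum each parameter's gradient across all processes, then divide by the
    # world size so every replica holds the averaged gradient
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

Calling this after loss.backward() and before optimizer.step() on every worker gives you data-parallel training without a dedicated server process.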

Please let me know in case you need any further details about the c10d library and my implementation.

I hope this doesn’t come too late…


Hi, no it’s really not late. After implementing my model using parameter server as well as rpc batch update, but unfortunately with no particular execution time improvement compared to single node hogwild, I think I kinda got the feeling on how to work with pytorch distributed for my use-case. Do you have any other resources for using the functions other than Distributed communication package - torch.distributed — PyTorch 1.9.0 documentation ? Btw, thank you so much for mentioning the c10d, I almost forgot about that because it was too complicated for me when I first read it when I was just learning pytorch distributed.

Hi, if you mean the c10d functions, then this might be helpful: Writing Distributed Applications with PyTorch — PyTorch Tutorials 1.9.0+cu102 documentation
