dist_autograd.context only computes local gradients


I am new to PyTorch distributed computing and am using PyTorch 1.9.0. I have a large model that I have distributed over several processes on a single computer. The model is organized so that the master process owns an input nn.Linear layer and an output nn.Linear layer, which execute at the very start and the very end of the forward pass, respectively. Between those two layers are many RemoteModule objects, each of which runs an RNN model; the overall goal is to predict some time-series data. I have successfully obtained RRefs to the parameters of each of these RemoteModules as follows:

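MyParameters = []
for module in MyModules:
    # remote_parameters() returns a list of RRefs, one per parameter of the remote module
    MyParameters.extend(module.remote_parameters())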

MyParameters then contains 1000 parameters.

Then, in the training loop running on the master process:

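import torch
import torch.distributed.autograd as dist_autograd
from torch import nn, optim
from torch.distributed.optim import DistributedOptimizer

# Create the distributed optimizer once, outside the training loop
dist_optim = DistributedOptimizer(optim.Adam, MyParameters, lr=0.001)
for batch in DataLoader:
    with dist_autograd.context() as context_id:
        input = InputLinear(batch)
        for t in range(sequenceLength):
            futs = []
            for mod in MyModules:
                futs.append(mod.forward_async(...))  # inputs for step t omitted
            get_state = [fut.wait() for fut in futs]
            state[:, t, :] = torch.stack(get_state)
        out = OutputLinear(state)
        loss = nn.MSELoss()(out, batch)  # instantiate the loss module, then call it
        dist_autograd.backward(context_id, [loss])
        grads = dist_autograd.get_gradients(context_id)
        dist_optim.step(context_id)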

At this point grads contains only 4 items, corresponding to the parameters of the InputLinear and OutputLinear layers. It does not contain gradients for any of the remote module parameters. If I understand correctly, the context_id should track computations across these remote modules. Can anyone tell me what I am doing wrong, or how I can properly compute the gradients for these remote modules?

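Hi @MLbrain, thanks for posting the question. As you observed, dist_autograd.get_gradients only returns the local gradients: the gradients of a RemoteModule's parameters are stored in the distributed autograd context on the rank where that module lives, not on the master. When training with RPC, you can simply pass the parameter RRefs into the DistributedOptimizer. On step(context_id), it runs a local optimizer on each owner, which reads the gradients from that owner's part of the context and updates the parameters in place. Note that the parameters of your local InputLinear and OutputLinear layers must also be in that list, wrapped as RRef(p), or they will not be updated. You can take a look at this tutorial: Getting Started with Distributed RPC Framework — PyTorch Tutorials 1.9.1+cu102 documentation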
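The behaviour described above can be sketched as a small two-process program. This is a minimal sketch under several assumptions: a POSIX system (it uses the `fork` start method), a free local TCP port, and a single `nn.Linear` on `worker1` standing in for the RNN RemoteModules from the question. The names `run_worker`, `owner_grad_count`, `free_port` and `main` are hypothetical helpers for illustration. Inside one `dist_autograd.context`, it counts the gradients visible on the master, asks the owner how many it holds, and then lets `DistributedOptimizer` apply the update on both workers.

```python
import multiprocessing
import os
import socket

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef


def owner_grad_count(context_id):
    # Runs on the owning worker: counts the gradients stored in its part of the context.
    return len(dist_autograd.get_gradients(context_id))


def run_worker(rank, world_size, port, results):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        input_linear = nn.Linear(4, 8)   # lives on the master
        output_linear = nn.Linear(8, 4)  # lives on the master
        # Hypothetical stand-in for the RNN RemoteModules: a Linear owned by worker1.
        remote_mod = RemoteModule("worker1/cpu", nn.Linear, args=(8, 8))

        # RRefs to every trainable parameter: remote ones, plus local ones wrapped in RRef.
        param_rrefs = remote_mod.remote_parameters()
        for layer in (input_linear, output_linear):
            param_rrefs.extend(RRef(p) for p in layer.parameters())

        # Created once; it sets up a local optimizer on each parameter owner.
        dist_optim = DistributedOptimizer(optim.Adam, param_rrefs, lr=0.001)

        batch = torch.randn(2, 4)
        with dist_autograd.context() as context_id:
            hidden = input_linear(batch)
            hidden = remote_mod.forward_async(hidden).wait()
            loss = F.mse_loss(output_linear(hidden), batch)
            dist_autograd.backward(context_id, [loss])
            # Only the master's own parameters appear here.
            local_count = len(dist_autograd.get_gradients(context_id))
            # The remote module's gradients live in worker1's part of the same context.
            remote_count = rpc.rpc_sync("worker1", owner_grad_count, args=(context_id,))
            # Each owner's optimizer reads the gradients from its own part of the context.
            dist_optim.step(context_id)
        results.put((local_count, remote_count))
    # Blocks until all workers have finished their outstanding RPC work.
    rpc.shutdown()


def free_port():
    # Ask the OS for an unused local port for the RPC rendezvous.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def main(world_size=2):
    ctx = multiprocessing.get_context("fork")  # assumption: POSIX system
    results = ctx.Queue()
    port = free_port()
    procs = [ctx.Process(target=run_worker, args=(rank, world_size, port, results))
             for rank in range(world_size)]
    for p in procs:
        p.start()
    counts = results.get(timeout=120)  # drain the queue before joining
    for p in procs:
        p.join()
    return counts


if __name__ == "__main__":
    local_count, remote_count = main()
    print(f"gradients on worker0: {local_count}, on worker1: {remote_count}")
```

`main()` returns a pair: the number of gradients `get_gradients` sees on `worker0` (the weights and biases of the two local layers) and the number stored on `worker1` (the remote layer's weight and bias). The `rpc_sync` call to the owner is only there to make the split visible; a normal training loop just calls `dist_autograd.backward` and then `dist_optim.step(context_id)`.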