Hello,
I am new to PyTorch distributed computing (using PyTorch 1.9.0). I have a large model that I have distributed over several processes on a single machine. The model is organized so that the master process holds an input nn.Linear layer and an output nn.Linear layer that execute at the very start and the very end of the forward pass, respectively. Between those two layers are many RemoteModule objects, each of which runs an RNN, with the overall goal of predicting some time-series data. I have successfully obtained RRefs to the parameters of each of these RemoteModules as:
MyModules.append(RemoteModule("worker" + str(proc), MyRNN))  # one RemoteModule per worker process

MyParameters = []
for module in MyModules:
    MyParameters.extend(module.remote_parameters())  # RRefs to each remote module's parameters
MyParameters contains 1000 parameter RRefs.
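For context, each process is initialized with RPC before the snippet above runs, and MyRNN is a small recurrent module. The sketch below is a simplified stand-in, not my exact code; numWorkers, nFeatures, and nHidden are placeholders:

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc

numWorkers = 4  # placeholder

# Master process (rank 0); each worker process proc calls
#   rpc.init_rpc("worker" + str(proc), rank=proc, world_size=numWorkers + 1)
# and rpc.shutdown() once training finishes.
# Assumes MASTER_ADDR / MASTER_PORT are set in the environment.
rpc.init_rpc("master", rank=0, world_size=numWorkers + 1)

class MyRNN(nn.Module):
    """Simplified stand-in for the per-worker RNN (placeholder sizes)."""
    def __init__(self, nFeatures=64, nHidden=64):
        super().__init__()
        self.cell = nn.RNNCell(nFeatures, nHidden)

    def forward(self, x, t):
        # x: output of the master's InputLinear, shape (batch, seq_len, nFeatures);
        # returns this module's hidden state for time step t
        return self.cell(x[:, t, :])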
Then, in the training loop running on the master process:
for batch in DataLoader:
    with dist_autograd.context() as context_id:
        input = InputLinear(batch)
        for t in range(sequenceLength):
            ids = []
            for mod in MyModules:
                ids.append(mod.forward_async(input, t))   # run each remote RNN asynchronously
            get_state = [id.wait() for id in ids]          # collect the remote outputs
            state[:, t, :] = torch.stack(get_state)        # state is preallocated before the loop
        out = OutputLinear(state)
        loss = nn.MSELoss()(out, batch)
        dist_autograd.backward(context_id, [loss])
        grads = dist_autograd.get_gradients(context_id)
        dist_optim = DistributedOptimizer(optim.Adam, MyParameters, lr=0.001)
        dist_optim.step(context_id)
At this point grads contains only 4 items, corresponding to the parameters of the InputLinear and OutputLinear layers only. It does not contain gradients for any of the remote module parameters. If I understand correctly, the distributed autograd context identified by context_id should be able to track computations across these remote modules. Can anyone tell me what I am doing wrong, or how I can properly compute the gradients for these remote modules?
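For reference, my understanding of the context tracking comes from the basic pattern in the distributed autograd examples, which looks roughly like this ("worker1" is a placeholder name, and rpc.init_rpc is assumed to have been called on all processes):

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    # the forward computation goes through an RPC to another worker
    loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
    dist_autograd.backward(context_id, [loss])
    # gradients are accumulated per context rather than in .grad
    grads = dist_autograd.get_gradients(context_id)   # contains entries for t1 and t2

In that simple case get_gradients on the caller does return gradients for the tensors involved in the RPC, so I expected the same context to cover the parameters of my remote modules as well.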