I just started with the torch.distributed RPC framework, and as a practice exercise I want to do some model parallelism with the BERT model. To start off, I wanted to offload the encoder layers to another worker and keep everything else local. Essentially, I want to replicate the following code in a distributed manner, where each modules[i] would live on a different worker node.
model = BertModel.from_pretrained('bert-base-uncased')
modules = list(model.children())
input_data = torch.zeros(1, 512, dtype=torch.long)
out1 = modules[0](input_data)
out2 = modules[1](handle_output(out1))
out3 = modules[2](handle_output(out2))
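For reference, the three children I get from list(model.children()) are the embeddings, encoder, and pooler (at least on my transformers version), which is easy to sanity-check:

print([type(m).__name__ for m in modules])
# ['BertEmbeddings', 'BertEncoder', 'BertPooler'] on my setup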
My handle output function is just a helper that extracts the tensor from the transformer output:
def handle_output(tensor):
    if hasattr(tensor, "last_hidden_state"):
        tensor = tensor.last_hidden_state
    if isinstance(tensor, tuple):
        tensor = tensor[0]
    return tensor
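For the dummy input above it just unwraps the encoder/model outputs into the usual hidden-state tensor, e.g.:

print(handle_output(out2).shape)
# torch.Size([1, 512, 768]) for bert-base with the 512-token dummy input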
The above code works without issue, but when I follow the same structure over RPC I get the error below saying the forward pass is missing a hidden_states argument, so I'm assuming the RRef I created for the data isn't really being passed through properly.
TypeError: TypeError: On WorkerInfo(id=0, name=encoder):
TypeError('On WorkerInfo(id=0, name=encoder):
TypeError("forward() missing 1 required positional argument: 'hidden_states'")
Traceback (most recent call last):
File "/home/stonks/miniconda3/envs/venv/lib/python3.9/site-packages/torch/distributed/rpc/internal.py", line 207, in _run_function
result = python_udf.func(*python_udf.args, **python_udf.kwargs)
File "/home/stonks/miniconda3/envs/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/stonks/miniconda3/envs/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
TypeError: forward() missing 1 required positional argument: 'hidden_states'
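For context, hidden_states is the first positional argument of the encoder's forward, so it looks like the encoder module is being invoked with no input at all on the remote side. This is how I checked the signature (output truncated, and it may differ across transformers versions):

import inspect
print(inspect.signature(modules[1].forward))
# (hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, ...) on my install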
Below is my RPC code:
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef, rpc_sync
from transformers import BertModel

# planned for the training step, not used yet:
# import torch.distributed.autograd as dist_autograd
# from torch.distributed.optim import DistributedOptimizer


def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs


def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    args = [method, rref] + list(args)
    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)


def handle_output(tensor):
    if hasattr(tensor, "last_hidden_state"):
        tensor = tensor.last_hidden_state
    if isinstance(tensor, tuple):
        tensor = tensor[0]
    return tensor


class Distributed(nn.Module):
    def __init__(self, model):
        super(Distributed, self).__init__()
        self.model = model
        self.submodules = list(model.children())
        self.emb = self.submodules[0]       # BertEmbeddings, kept local
        self.encoder_rref = rpc.remote('encoder', self.submodules[1])  # BertEncoder, meant to live on the 'encoder' worker
        self.pooler = self.submodules[2]    # BertPooler, kept local
        print('Model constructed')

    def forward(self, input_rref):
        out_emb = self.emb(input_rref.to_here())
        out_encoder = handle_output(_remote_method(self.submodules[1], self.encoder_rref, out_emb))
        out_pooled = self.pooler(out_encoder)
        return out_pooled

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_parameter_rrefs(self.emb))
        remote_params.extend(_remote_method(_parameter_rrefs, self.encoder_rref))
        remote_params.extend(_parameter_rrefs(self.pooler))
        return remote_params


def train_test():
    # worker 0
    bert = BertModel.from_pretrained("bert-base-uncased")
    model = Distributed(bert)
    print('done')
    # opt = DistributedOptimizer(
    #     torch.optim.Adam,
    #     model.parameter_rrefs(),
    # )
    dummy_input = torch.zeros((1, 1), dtype=torch.long)
    dummy_rref = RRef(dummy_input)
    print(dummy_rref)
    output_ref = model(dummy_rref)
    output = output_ref.to_here()
    print(output)
    # loss = output[0].sum()
    # dist_autograd.backward([loss])
    # opt.step()


def run_worker(rank, world_size):
    rpc.init_rpc("encoder", rank=rank, world_size=world_size)
    train_test()
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 1
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
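For the training step, this is roughly what I'm planning to use in place of the commented-out lines once the forward works, based on the RPC parameter-server tutorial (completely untested on my end, so treat it as a sketch):

import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer

opt = DistributedOptimizer(
    torch.optim.Adam,
    model.parameter_rrefs(),
    lr=1e-4,
)

with dist_autograd.context() as context_id:
    output = model(dummy_rref)
    loss = output.sum()
    # the distributed backward pass needs the autograd context id
    dist_autograd.backward(context_id, [loss])
    # DistributedOptimizer.step also takes the context id
    opt.step(context_id)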
If anyone is able to help with this it would be very much appreciated. I am also hoping to get that backward pass and distributed optimizer working at some point, so tips on that end would be welcome too. Finally, some general advice would be great: I am finding the distributed side of things quite a steep learning curve even with good prior PyTorch experience. What would be the best way to get proficient with this framework? Thanks in advance.