Torch distributed for Bert Model

I just started with the torch.distributed framework, and as a practice example I want to do some model parallelism with the BERT model. To start off, I wanted to offload the encoder layers to another worker and keep everything else local. Essentially, I want to replicate the following code in a distributed manner, where modules[i] would live on different worker nodes.

import torch
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
modules = list(model.children())  # [embeddings, encoder, pooler]

input_data = torch.zeros(1, 512, dtype=torch.long)

out1 = modules[0](input_data)               # embeddings
out2 = modules[1](handle_output(out1))      # encoder
out3 = modules[2](handle_output(out2))      # pooler

My handle_output function is just a helper that extracts the hidden-state tensor from the transformers output object:

def handle_output(tensor):
    if hasattr(tensor, "last_hidden_state"):
        tensor = tensor.last_hidden_state
    if isinstance(tensor, tuple):
        tensor = tensor[0]
    return tensor

The above code works without issue, but when I follow the same structure over RPC I get the error below in the forward pass saying that the hidden_states argument is missing, so I am assuming the RRef object I created for the data isn't really working properly.

TypeError: TypeError: On WorkerInfo(id=0, name=encoder):
TypeError('On WorkerInfo(id=0, name=encoder):
TypeError("forward() missing 1 required positional argument: 'hidden_states'")
Traceback (most recent call last):
  File "/home/stonks/miniconda3/envs/venv/lib/python3.9/site-packages/torch/distributed/rpc/internal.py", line 207, in _run_function
    result = python_udf.func(*python_udf.args, **python_udf.kwargs)
  File "/home/stonks/miniconda3/envs/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/stonks/miniconda3/envs/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() missing 1 required positional argument: 'hidden_states'

Below is my RPC code:

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef, rpc_sync
from transformers import BertModel


def _parameter_rrefs(module):
    # wrap each parameter of a local module in an RRef so the
    # DistributedOptimizer can reference parameters on any worker
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

def _call_method(method, rref, *args, **kwargs):
    r"""
    a helper function to call a method on the given RRef
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    r"""
    a helper function to run method on the owner of rref and fetch back the
    result using RPC
    """
    args = [method, rref] + list(args)
    return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
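
(These two helpers are adapted from the PyTorch RPC examples. As I understand it, a call like _remote_method(_parameter_rrefs, self.encoder_rref) should run _parameter_rrefs on whichever worker owns the encoder RRef and send the result back over RPC.)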

def handle_output(tensor):
    if hasattr(tensor, "last_hidden_state"):
        tensor = tensor.last_hidden_state
    if isinstance(tensor, tuple):
        tensor = tensor[0]
    return tensor



class Distributed(nn.Module):
    def __init__(self, model):
        super(Distributed, self).__init__()
        
        self.model = model
        self.submodules = list(model.children())  # [embeddings, encoder, pooler]

        # embeddings and pooler stay local; the encoder is meant to live on the 'encoder' worker
        self.emb = self.submodules[0]
        self.encoder_rref = rpc.remote('encoder', self.submodules[1])
        self.pooler = self.submodules[2]
        print('Model constructed')

    def forward(self, input_rref):
        out_emb = self.emb(input_rref.to_here())  # embeddings run locally
        # intended: run the encoder on the worker that owns it and pull the output back
        out_encoder = handle_output(_remote_method(self.submodules[1], self.encoder_rref, out_emb))
        out_pooled = self.pooler(out_encoder)
        return out_pooled

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_parameter_rrefs(self.emb))
        remote_params.extend(_remote_method(_parameter_rrefs, self.encoder_rref))
        remote_params.extend(_parameter_rrefs(self.pooler))
        return remote_params
            

def train_test():
    # worker 0
    bert = BertModel.from_pretrained("bert-base-uncased")
    model = Distributed(bert)
    print('done')

    # opt = DistributedOptimizer(
    #     torch.optim.Adam,
    #     model.parameter_rrefs(),
    # )
    
    dummy_input = torch.zeros((1, 1), dtype=torch.long)

    dummy_rref = RRef(dummy_input)
    print(dummy_rref)
    output_ref = model(dummy_rref)
    output = output_ref.to_here()
    print(output)
    
    # loss = output[0].sum()
    # dist_autograd.backward([loss])
    # opt.step()

def run_worker(rank, world_size):

    rpc.init_rpc("encoder", rank=rank, world_size=world_size)
    train_test()

    rpc.shutdown()
    

if __name__ == "__main__":
    world_size = 1
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)

If anyone is able to help with this it would be very much appreciated. I am also hoping to get the backward pass and optimizer working at some point, so tips on that end would be welcome too. More generally, I am finding the distributed side quite a learning curve even with decent prior PyTorch experience, so what would be the best way to get proficient with this framework? Thanks in advance.
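
In case it helps, this is roughly what I was planning for the training step once the forward pass works, based on my reading of the distributed autograd / DistributedOptimizer docs (completely untested, and the lr value is just a placeholder):

    # build the optimizer once over the parameter RRefs gathered from all workers
    opt = DistributedOptimizer(
        torch.optim.Adam,
        model.parameter_rrefs(),
        lr=1e-5,
    )

    with dist_autograd.context() as context_id:
        output = model(dummy_rref)
        loss = output.sum()  # assuming forward returns the pooled tensor locally
        # distributed backward pass across the local and remote modules
        dist_autograd.backward(context_id, [loss])
        # run the optimizer step on every worker that owns some of the parameters
        opt.step(context_id)

Does that look like the right overall pattern?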