Passing results between workers using RPC framework

Dear @mrshenli,

I have successfully run the RNN RPC tutorial/example shown in the link below:

https://pytorch.org/tutorials/intermediate/rpc_tutorial.html

This was very helpful for understanding the basics of the RPC framework, and it demonstrates very clearly how to do distributed training with two machines (nodes).
However, my question is: what if we have 3 or more nodes (workers) and we want to split the model into submodels, one on each machine/node/worker? How should I use the RPC framework to pass the intermediate result from the previous worker to the next worker?

Take the RNN tutorial/model as an example:

In the tutorial, the forward pass basically does this:

def forward(self, input, hidden):
    # pass input to the remote embedding table and fetch the emb tensor back
    emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
    output, hidden = self.rnn(emb, hidden)
    # pass output to the remote decoder and get the decoded output back
    decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
    return decoded, hidden
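
(For reference, _remote_method is the helper from the tutorial that runs a method on the worker owning the given RRef. A rough sketch of that helper pattern is below; please double-check the tutorial for the exact definitions.)

import torch.distributed.rpc as rpc

def _call_method(method, rref, *args, **kwargs):
    # runs on the owner of rref: unwrap the RRef to its local value and call the method
    return method(rref.local_value(), *args, **kwargs)

def _remote_method(method, rref, *args, **kwargs):
    # synchronous RPC to the worker that owns rref
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)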

  1. By using rpc.remote and rpc_sync, I can place the EmbeddingTable on worker#1, do its forward pass remotely, and get the EmbeddingTable’s result back locally.

  2. Then I pass the EmbeddingTable’s result to my local RNN (worker#0) and get the corresponding RNN result.

  3. I have another Decoder remotely on worker#1 again; I push the RNN result to that Decoder and then get the result back using rpc_sync.

However, what if I have three workers: worker#0 (local), worker#1, and worker#2? This time, I put the RNN model on remote worker#2 and I want the communication to look like this:

  1. From worker#0, I push the input to the EmbeddingTable on worker#1. After worker#1 finishes the calculation, it passes the result to the RNN on worker#2.

  2. The RNN on worker#2 computes and passes the result back to worker#1 for the Decoder.

  3. After the Decoder on worker#1 finishes the computation, I (on worker#0) use rpc_sync or to_here to get the final result back locally.

Do you think this is possible, and could you let me know how to do it? Also, can Distributed Autograd and the Distributed Optimizer still be applied in this scenario?

One of my thoughts is that I could make an RRef to each submodule’s output and pass those RRefs among the workers.

Thank you very much in advance for your time and help!

Best,
Ziyi

Hey @ZiyiZhu, thanks for trying out RPC.

Do you think this is possible, and could you let me know how to do it?

One of my thoughts is that I could make an RRef to each submodule’s output and pass those RRefs among the workers.

Solution I

Yes, this is possible, and using RRef is certainly a proper solution. More specifically, we can let worker 0 serve as the master here. Something like:

# On worker 0
emb_lookup_rref = rpc.remote("worker1", EmbeddingTable.forward, args=(input,))
# note that RNN.forward needs to call to_here() on emb_lookup_rref
rnn_result_rref = rpc.remote("worker2", RNN.forward, args=(emb_lookup_rref,))
# similarly Decoder also needs to call to_here() on rnn_result_rref
decoded_rref = rpc.remote("worker1", Decoder.forward, args=(rnn_result_rref,))
final_result = decoded_rref.to_here()

The above should work. Although it results in several additional lightweight messages to manage the internal RRef reference count, it shouldn’t slow down training. Because rpc.remote is non-blocking, it returns an RRef immediately. It’s very likely that the RNN module is already waiting on to_here() for the embedding lookup result even before the EmbeddingTable has finished processing the request, so there shouldn’t be any noticeable delay on the critical path.
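
To illustrate the "needs to call to_here()" comments above: the callee’s forward simply takes the RRef and unwraps it before computing. A minimal sketch, using a hypothetical RNNShard module with made-up dimensions (in practice the call would go through an RRef to the module instance, e.g. via a _call_method-style helper):

import torch.nn as nn

class RNNShard(nn.Module):
    # hypothetical stand-in for the RNN piece living on worker2
    def __init__(self, emb_dim=64, hidden_dim=128):
        super().__init__()
        self.lstm = nn.LSTM(emb_dim, hidden_dim)

    def forward(self, emb_rref):
        # block until the embedding lookup result arrives from its owner
        emb = emb_rref.to_here()
        output, _ = self.lstm(emb)
        return output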

Solution II

An alternative solution is to use nested RPC. You can wrap the EmbeddingTable → RNN → Decoder chain into one module forward function (say, MyModel.forward), and then let worker 0 run rpc.rpc_sync("worker1", MyModel.forward, args=(input,)). Within MyModel.forward, it can in turn use rpc_sync/remote to communicate with worker 2, something like:

class MyModel(nn.Module):
    def forward(self, input):
        lookup_result = self.emb(input)
        # here we pass the lookup result directly, instead of an RRef, to the RNN
        rnn_result = rpc.rpc_sync("worker2", RNN.forward, args=(lookup_result,))
        return self.decoder(rnn_result)
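
For the top-level call, a rough sketch of what worker 0 could run, assuming MyModel has been constructed on worker1 behind an RRef and reusing the tutorial-style _call_method helper (input is a placeholder tensor):

# On worker 0: construct MyModel on worker1, then run the whole nested
# pipeline with one synchronous RPC.
model_rref = rpc.remote("worker1", MyModel)
decoded = rpc.rpc_sync(model_rref.owner(), _call_method,
                       args=(MyModel.forward, model_rref, input))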

Sorry, I missed this. Yes, both of them would still work in this case, as long as you wrap the top-level (not the nested) RPCs in a distributed autograd context. All RPCs originating from that context will propagate the context id, so the autograd engine and optimizer on different workers will be able to find the context properly.

For the Distributed Optimizer, as long as (1) you provide the correct list of parameter RRefs to its constructor and (2) its step() function is wrapped in the correct dist autograd context, it should work. It does not care where those parameters live.

BTW, in the next release we are making dist autograd/optimizer functional. They will take the context id as an argument and will no longer need to be wrapped in a with-context statement.
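
For concreteness, a rough sketch of that wrapping with the current (v1.4-style) API; model, param_rrefs, inputs, and target below are placeholders:

import torch.nn.functional as F
import torch.optim as optim
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer

opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)

with dist_autograd.context() as context_id:
    # every RPC issued inside this block propagates the autograd context id
    output = model(inputs)           # call .to_here() first if forward returns an RRef
    loss = F.cross_entropy(output, target)
    dist_autograd.backward([loss])   # v1.4 API; the next release takes context_id explicitly
    opt.step()                       # same here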

Hi @mrshenli,

Thank you very much for the solutions. I quickly tried the first solution and it works! I will try it with Distributed Autograd and the Distributed Optimizer once I construct the entire network for training. However, one of my concerns is that when we train the network there are many iterations and epochs, which means we will run lots of forward passes like the following:

# On worker 0
emb_lookup_rref = rpc.remote("worker1", EmbeddingTable.forward, args=(input,))
# note that RNN.forward needs to call to_here() on emb_lookup_rref
rnn_result_rref = rpc.remote("worker2", RNN.forward, args=(emb_lookup_rref,))
# similarly Decoder also needs to call to_here() on rnn_result_rref
decoded_rref = rpc.remote("worker1", Decoder.forward, args=(rnn_result_rref,))
final_result = decoded_rref.to_here()

Tons of RRefs (emb_lookup_rref, rnn_result_rref, and decoded_rref) will be created by rpc.remote. Should I worry about this, or will these RRefs be destroyed automatically?

Thank you!!

RPC automatically tracks the RRef reference count. This describes the algorithm. The object referenced by an RRef is deleted automatically when its reference count drops to 0, so those RRefs should be cleaned up automatically, and we have seen this work correctly in intensive training applications. One thing I want to mention is that this relies on Python GC to delete variables like emb_lookup_rref in time, which should be the case as long as there is no circular reference pointing to the RRef. Let us know if you hit OOM due to RRefs; we can probably expose explicit deletion APIs if necessary.
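
If you ever want to be extra explicit, plain Python del at the end of each iteration is enough to drop the local references (a sketch reusing the names from your snippet); normally this is unnecessary:

for input, target in train_loader:   # hypothetical training loop
    emb_lookup_rref = rpc.remote("worker1", EmbeddingTable.forward, args=(input,))
    rnn_result_rref = rpc.remote("worker2", RNN.forward, args=(emb_lookup_rref,))
    decoded_rref = rpc.remote("worker1", Decoder.forward, args=(rnn_result_rref,))
    final_result = decoded_rref.to_here()
    # ... loss / backward / step ...
    del emb_lookup_rref, rnn_result_rref, decoded_rref   # optional; GC handles this anyway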


Dear @mrshenli,

I tested the RPC framework with two nodes for a model-parallelism implementation. Distributed Autograd and the Distributed Optimizer work successfully the way I constructed them, following the template in the RPC tutorial https://pytorch.org/tutorials/intermediate/rpc_tutorial.html. However, I do see a memory problem on the GPU: the memory usage grows with the number of epochs. I wonder if you could let me know what the problem might be.

I constructed a very simple CNN for classifying the FashionMNIST dataset and divided it into two submodels, one for the convolutional layers and the other for the fully connected layers, as below:

class ConvNet(nn.Module):

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5).to(self.device)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5).to(self.device)

    def forward(self, rref):
        t = rref.to_here().to(self.device)
        # conv 1
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # conv 2
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        return t.cpu()
    
class FCNet(nn.Module):

    def __init__(self,device):
        super().__init__()
        self.device = device
        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120).to(self.device)
        self.fc2 = nn.Linear(in_features=120, out_features=60).to(self.device)
        self.out = nn.Linear(in_features=60, out_features=10).to(self.device)

    def forward(self, rref):
       
        t = rref.to_here().to(self.device)

        # fc1
        t = t.reshape(-1, 12*4*4)
        t = self.fc1(t)
        t = F.relu(t)

        # fc2
        t = self.fc2(t)
        t = F.relu(t)

        # output
        t = self.out(t)
        # don't need softmax here since we'll use cross-entropy as the loss.

        return t.cpu()

To wrap them up, I created a CNNModel class that constructs the two submodules remotely and performs the forward pass:

class CNNModel(nn.Module):
    def __init__(self, convnet_wk, fcnet_wk, device):
        super(CNNModel, self).__init__()

        self.device = device

        # construct the ConvNet on the worker named by convnet_wk
        self.convnet_rref = rpc.remote(convnet_wk, ConvNet, args=(device,))
        print(self.convnet_rref.to_here())
        # construct the FCNet on the worker named by fcnet_wk
        self.fcnet_rref = rpc.remote(fcnet_wk, FCNet, args=(device,))
        print(self.fcnet_rref.to_here())
        print('CNN model constructed: ' + 'owner')

    def forward(self, input_rref):
        convnet_forward_rref = rpc.remote(self.convnet_rref.owner(), _call_method,
                                          args=(ConvNet.forward, self.convnet_rref, input_rref))

        fcnet_forward_rref = rpc.remote(self.fcnet_rref.owner(), _call_method,
                                        args=(FCNet.forward, self.fcnet_rref, convnet_forward_rref))

        return fcnet_forward_rref
    
    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_remote_method(_parameter_rrefs, self.convnet_rref))
        remote_params.extend(_remote_method(_parameter_rrefs, self.fcnet_rref))
        return remote_params
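
Here, _call_method, _remote_method, and _parameter_rrefs are the same tutorial-style helpers as above; for completeness, _parameter_rrefs is roughly the sketch below (please check the tutorial for the exact definition):

from torch.distributed.rpc import RRef

def _parameter_rrefs(module):
    # one RRef per local parameter so the DistributedOptimizer can reach them
    return [RRef(p) for p in module.parameters()]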

For training, I have a Trainer that uses Distributed Autograd and the Distributed Optimizer:


class Trainer(object):

    def __init__(self, model, optimizer, train_loader, test_loader, device):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

    def fit(self, epochs):
        for epoch in range(1, epochs + 1):
            train_loss, train_acc = self.train()
            test_loss, test_acc = self.evaluate()

            print(
                'Epoch: {}/{},'.format(epoch, epochs),
                'train loss: {}, train acc: {},'.format(train_loss, train_acc),
                'test loss: {}, test acc: {}.'.format(test_loss, test_acc),
            )

    def train(self):

        train_loss = Average()
        train_acc = Accuracy()

        for data, target in self.train_loader:
            with dist_autograd.context() as context_id:
                data_ref = RRef(data)

                output_ref = self.model(data_ref)
                output = output_ref.to_here()
                loss = F.cross_entropy(output, target)

                dist_autograd.backward([loss])
                self.optimizer.step()

                train_loss.update(loss.item(), data.size(0))
                train_acc.update(output, target)

        return train_loss, train_acc

    def evaluate(self):
        self.model.eval()

        test_loss = Average()
        test_acc = Accuracy()

        with torch.no_grad():
            for data, target in self.test_loader:
                with dist_autograd.context() as context_id:
                    data_ref = RRef(data)

                    output_ref = self.model(data_ref)
                    output = output_ref.to_here()
                    loss = F.cross_entropy(output, target)

                    test_loss.update(loss.item(), data.size(0))
                    test_acc.update(output, target)

        return test_loss, test_acc

At the top level, I initialize the RPC framework, create a CNNModel, and pass the Distributed Optimizer to the trainer:

# Worker 0

def run(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    model = CNNModel(args['host'], args['worker'],device)
    
    # setup distributed optimizer
    opt = DistributedOptimizer(
        optim.Adam,
        model.parameter_rrefs(),
        lr=args['lr'],
    )

    train_loader = MNISTDataLoader(args['root'], args['batch_size'], train=True)
    test_loader = MNISTDataLoader(args['root'], args['batch_size'], train=False)

    trainer = Trainer(model, opt, train_loader, test_loader, device)
    trainer.fit(args['epochs'])

def main():
    argv = {'world_size': int(2),
            'rank': int(0),
            'host': "worker0",
            'worker': "worker1",
            'epochs': int(10),
            'lr': float(1e-3),
            'root': 'data',
            'batch_size': int(32)
           }
    
    print(argv)
    rpc.init_rpc(argv['host'], rank=argv['rank'], world_size=argv['world_size'])
    print('Start Run', argv['rank'])
    run(argv)
    rpc.shutdown()

os.environ['MASTER_ADDR'] = '10.142.0.13'  # Google Cloud
# os.environ['MASTER_ADDR'] = 'localhost'  # local
os.environ['MASTER_PORT'] = '29505'
main()
# Worker 1

def main():
    argv = {'world_size': int(2),
            'rank': int(1),
            'host': 'worker0',
            'worker': 'worker1',
            'epochs': int(10),
            'lr': float(1e-3),
            'root': 'data',
            'batch_size': int(32)
           }
    
    print(argv)
    rpc.init_rpc(argv['worker'], rank=argv['rank'], world_size=argv['world_size'])
    print('Start Run', argv['rank'])
    rpc.shutdown()

os.environ['MASTER_ADDR'] = '10.142.0.13'  # Google Cloud
# os.environ['MASTER_ADDR'] = 'localhost'  # local
os.environ['MASTER_PORT'] = '29505'
main()

The data loader is as follows:

from torch.utils import data
from torchvision import datasets, transforms

class MNISTDataLoader(data.DataLoader):

    def __init__(self, root, batch_size, train=True):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        dataset = datasets.FashionMNIST(root, train=train, transform=transform, download=True)
        

        super(MNISTDataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=train,
        )

I have shown all the details above, but I guess the problem could be in the way I constructed the CNNModel around the ConvNet and FCNet. I wonder if you could take a look at the code and give some hints on where the problem might be?

Thank you very much for your time!

Best,
Ziyi

How fast does the memory usage increase? Does it keep increasing after every epoch, or does it stabilize after a few epochs?

It could be that an RRef or a distributed autograd context wasn’t deleted in time. It might be worth providing an API that blocks until all RPC workers have cleared their RRefs and dist autograd contexts. cc @pritamdamania87

Sorry, what I wrote in the previous post was not clear. The GPU memory usage keeps growing during training, not necessarily in relation to the epochs. I took a closer look at the GPU memory usage. On worker #1, I kept running

nvidia-smi

while the program was running, and the memory usage kept increasing. The initial GPU memory usage is small, around 500 MB, but with RPC it keeps growing: every time I run nvidia-smi I can see a few more MB in use, and between epochs I can also see a big jump in memory usage.

Best,
Ziyi

Looks like there is a leak.

For RRef, there is an internal API _rref_context_get_debug_info to check the number of living OwnerRRefs (under the key num_owner_rrefs). Here are examples in tests.
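
A rough usage sketch (note this is an internal API, and the exact import location below is an assumption that may vary across versions):

# Internal/debug API; the import path is an assumption and may differ by version.
from torch.distributed.rpc.api import _rref_context_get_debug_info

info = _rref_context_get_debug_info()
print('live OwnerRRefs on this worker:', info['num_owner_rrefs'])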

Similarly, distributed autograd can also show the number of living backward passes. [example]

BTW, which version of PyTorch are you using? I recall we fixed some memory leaks after the v1.4 release, e.g., this PR.

Hi @mrshenli,

I am using PyTorch 1.4.0, installed with the following command from the official website:

conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

Sure, I will check those posts and see if I can figure out the problem myself. Please let me know if the version is a problem, and it would be much appreciated if you could share more thoughts on the code I provided above.

Thank you!

Best,
Ziyi

Please let me know if the version is a problem

Yes, the memory leak could have been fixed by PRs landed after v1.4 (the v1.4 branch cut was 12/08/2019 IIRC). Can you try the master branch or the nightly build (conda install pytorch torchvision -c pytorch-nightly)? Or, if you can share a self-contained runnable script, we can try that too.

if you could share more thoughts on the code I provided above.

The code you shared above looks correct to me.

Hi @mrshenli,

Yes, I just tested my code with the PyTorch nightly build and it no longer has the memory leak issue. Also, the syntax is now closer to what you show in the tutorial.

Thank you very much for your help!

Best,
Ziyi


Awesome! Thanks for sharing the result!

Hi @mrshenli,

Sorry, my earlier test was okay. However, I just created a new instance on Google Cloud and installed the latest PyTorch nightly, and now it raises this error during training:

I can share the code for the RPC training; please take a look if you have time!

Thank you.

IIUC, this is caused by a race introduced in a recent PR. @pritamdamania87 has a fix for it. BTW, this bug is not in v1.5; if you install the v1.5 RC or wait for the fix to land in the nightly, this error should disappear.

Thank you very much for the information and your response. Sure, I will wait for the fix and the new release.

Best,
Ziyi