Rref can only be used once

hi,
I implement a RPC distributed model like the tutorial example(rnn part)

I put an embedding module locally and a pretrained bert model remotely. The whole model is as follows:

class DistCPM(nn.Module):
    def __init__(self,
                 save_prompt_embedding_path,
                 device="cuda:0",
                 ps="ps"):
        super(DistCPM, self).__init__()
        # setup prompt locally(I want to add some prompt tokens to improve the accuracy)
        self.prompt_embedding_policy = PromptEmbeddingPolicy(
            prompt_embeds_path=save_prompt_embedding_path,
            device=device
        )
        # setup bert remotely
        # In addition to create a LanguageModelService model(e.g bert/bart), I also load its checkpoint. This will take a while
        # And the requires_grad is set to false to make the remote model has no grad.
        self.language_model_service_rref = rpc.remote(ps, LanguageModelService, args=())

    def forward(self,
                word_embeds,
                pad_emnbeds,
                batch_target_sequence_token):
        # make input embeddings and so on
        model_batch, _ = self.prompt_embedding_policy(
            word_embeds, pad_emnbeds, batch_target_sequence_token
        )
        logits = _remote_method(LanguageModelService.forward,
                                self.language_model_service_rref,
                                model_batch)
        return logits
    # return all params(local+remote) just as the example
    def parameter_rrefs(self):
        remote_params = []
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.prompt_embedding_policy))
        # get RRefs of bert(or bart and so on.Given I set bert.requires_grad=False, the following line may not be necessary)
        remote_params.extend(_remote_method(_parameter_rrefs, self.language_model_service_rref))
        return remote_params

# trainer, also like that in the example
def _run_trainer():
    r"""
    The trainer creates a distributed RNNModel and a DistributedOptimizer. Then,
    it performs training using random input data.
    """
    train_dataset_path = 'data/diagnose/train_1000.jsonl'
    test_dataset_path = 'data/diagnose/test.jsonl'
    save_prompt_embedding_path = 'data/checkpoint/embedding.pth'  # if exists, will be load

    device0 = "cuda:0"

    model = DistCPM(save_prompt_embedding_path, device=device0, ps="ps") 

    criterion = torch.nn.CrossEntropyLoss()

    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    # load bert word embeddings
    # I will concate the prompt tokens embedding and bert-tokens embedding
    word_embedding_service = WordEmbeddingService(device0)

    batch_size = 4
    batch_train_dataset = load_dataset(train_dataset_path, batch_size)

    for k in range(100):
        for bix, batch in enumerate(batch_train_dataset):
            print(f"epoch-{k} batch-{bix}")
            with dist_autograd.context() as context_id:
                batch_information = {'special_vocab_string_list': ['<pad>', '<mask>'],
                                     'batch_source_sequence_string': batch['source'],
                                     'batch_target_sequence_string': batch['target']}
                
                # The next few lines create input embeddings.
                batch_information = word_embedding_service.get_word_embedding(batch_information)
                word_embeds = batch_information['batch_sequence_embedding']
                pad_emnbeds = batch_information['special_vocab_embedding_list'][0]
                batch_target_sequence_token = batch_information['batch_target_sequence_token']
                
                # logits
                logits = model(word_embeds, pad_emnbeds, batch_target_sequence_token)
                logits = logits.float()

                target = batch_target_sequence_token

                loss = 0.0
                target_token_len = []
                # calculate loss trivially
                for index in range(len(target)):
                    if logits[index].shape[0] != target[index].shape[0]:
                        pass
                    else:
                        loss += criterion(logits[index], target[index])
                    target_token_len.append(target[index].shape[0])
                loss = loss / sum(target_token_len)
                print('batch_train_loss:', loss)

                # run distributed backward pass
                dist_autograd.backward(context_id, [loss])
                # run distributed optimizer
                opt.step(context_id)

the error is: RuntimeError: RRef creation via rpc.remote() timed out, and it is possible that the RRef on the owner node does not exist.

I find the RREF object(namely self.language_model_service_rref) only can be used once.
The first usage occurs in optimizer:

opt = DistributedOptimizer(
        optim.SGD,
        # [RRef(prompt_embedding_policy.prompt_embeds)],
        model.parameter_rrefs(),  # get the remote params will use self.language_model_service_rref)
        lr=0.05,
    )

The second usage is in model(*args) where the error occurs:

logits = model(word_embeds, pad_emnbeds, batch_target_sequence_token)

# model.forward()
def forward(self,
                word_embeds,
                pad_emnbeds,
                batch_target_sequence_token):
        # make input embeddings and so on
        model_batch, _ = self.prompt_embedding_policy(
            word_embeds, pad_emnbeds, batch_target_sequence_token
        )
        logits = _remote_method(LanguageModelService.forward,  # the second usage
                                self.language_model_service_rref,
                                model_batch)  

Any comment would be appreciated.

Does this mean that this error won’t occur if you don’t use DistributedOptimizer?

cc: @mrshenli

In the tutorial, the distributedOptimizer is necessary.

On the other hand, even the rref object(self.language_model_service_rref) is only used in the model(used in model.forward() exactly), the same error occurs when the second batch begins to be handled.

# remove the rref object from optimizer
opt = DistributedOptimizer(
        optim.SGD,

        # Just get the local params and wrap it with RRef,
        # because I need no update to the remote bert params
        [RRef(prompt_embedding_policy.params)],  

        # model.parameter_rrefs(),  # get the remote params, will use self.language_model_service_rref)
        lr=0.05,
    )
...
# when the model begins to deal with the second batch, the same error occurs(the rref object does not exist).
for batch in data_loader:
    logits = model(batch)  # note just an example here

hi @111282, when you create self.language_model_service_rref it takes sometime as you mentioned, but the result rref is returned to you immediately, although the underlying remote object is not created yet.
This creation takes more time than is allowed by the default timeout and thus it fails silently. But when you try to call _remote_method(LanguageModelService.forward, self.language_model_service_rref,...) this error is shown to you. This behavior is described in remote method timeout argument:

timeout in seconds for this remote call. If the creation of this RRef on worker to is not successfully processed on this worker within this timeout, then the next time there is an attempt to use the RRef (such as to_here() ), a timeout will be raised indicating this failure.

What you need to do is to add some timeout greater than default one(60) here:

self.language_model_service_rref = rpc.remote(ps, LanguageModelService, args=(), timeout=...)
2 Likes

@pbelevich You are right. But I am still wondering why the error occurs on the second use?

The above error occurs when the second usage(namely model(*args)) happens. The first usage is getting the remote params when creating DistributedOptimizer.

Also even I remove the first usage(code as follows), the same error occurs when the model begins to deal with the second batch(The first batch is ok). So I thought the rref object can only be used once.

# remove the rref object from optimizer
opt = DistributedOptimizer(
        optim.SGD,

        # Just get the local params and wrap it with RRef,
        # because I need no update to the remote bert params
        # and bert.requires_grad = False
        [RRef(prompt_embedding_policy.params)],  

        # model.parameter_rrefs(),  # get the remote params, will use self.language_model_service_rref)
        lr=0.05,
    )
...
# Now the same error occurs when the second batch begins to be dealt with.
for batch in data_loader:
    logits = model(batch)  # note just an example here