hi,
I implement a RPC distributed model like the tutorial example(rnn part)
I put an embedding module locally and a pretrained bert model remotely. The whole model is as follows:
class DistCPM(nn.Module):
def __init__(self,
save_prompt_embedding_path,
device="cuda:0",
ps="ps"):
super(DistCPM, self).__init__()
# setup prompt locally(I want to add some prompt tokens to improve the accuracy)
self.prompt_embedding_policy = PromptEmbeddingPolicy(
prompt_embeds_path=save_prompt_embedding_path,
device=device
)
# setup bert remotely
# In addition to create a LanguageModelService model(e.g bert/bart), I also load its checkpoint. This will take a while
# And the requires_grad is set to false to make the remote model has no grad.
self.language_model_service_rref = rpc.remote(ps, LanguageModelService, args=())
def forward(self,
word_embeds,
pad_emnbeds,
batch_target_sequence_token):
# make input embeddings and so on
model_batch, _ = self.prompt_embedding_policy(
word_embeds, pad_emnbeds, batch_target_sequence_token
)
logits = _remote_method(LanguageModelService.forward,
self.language_model_service_rref,
model_batch)
return logits
# return all params(local+remote) just as the example
def parameter_rrefs(self):
remote_params = []
# create RRefs for local parameters
remote_params.extend(_parameter_rrefs(self.prompt_embedding_policy))
# get RRefs of bert(or bart and so on.Given I set bert.requires_grad=False, the following line may not be necessary)
remote_params.extend(_remote_method(_parameter_rrefs, self.language_model_service_rref))
return remote_params
# trainer, also like that in the example
def _run_trainer():
r"""
The trainer creates a distributed RNNModel and a DistributedOptimizer. Then,
it performs training using random input data.
"""
train_dataset_path = 'data/diagnose/train_1000.jsonl'
test_dataset_path = 'data/diagnose/test.jsonl'
save_prompt_embedding_path = 'data/checkpoint/embedding.pth' # if exists, will be load
device0 = "cuda:0"
model = DistCPM(save_prompt_embedding_path, device=device0, ps="ps")
criterion = torch.nn.CrossEntropyLoss()
opt = DistributedOptimizer(
optim.SGD,
model.parameter_rrefs(),
lr=0.05,
)
# load bert word embeddings
# I will concate the prompt tokens embedding and bert-tokens embedding
word_embedding_service = WordEmbeddingService(device0)
batch_size = 4
batch_train_dataset = load_dataset(train_dataset_path, batch_size)
for k in range(100):
for bix, batch in enumerate(batch_train_dataset):
print(f"epoch-{k} batch-{bix}")
with dist_autograd.context() as context_id:
batch_information = {'special_vocab_string_list': ['<pad>', '<mask>'],
'batch_source_sequence_string': batch['source'],
'batch_target_sequence_string': batch['target']}
# The next few lines create input embeddings.
batch_information = word_embedding_service.get_word_embedding(batch_information)
word_embeds = batch_information['batch_sequence_embedding']
pad_emnbeds = batch_information['special_vocab_embedding_list'][0]
batch_target_sequence_token = batch_information['batch_target_sequence_token']
# logits
logits = model(word_embeds, pad_emnbeds, batch_target_sequence_token)
logits = logits.float()
target = batch_target_sequence_token
loss = 0.0
target_token_len = []
# calculate loss trivially
for index in range(len(target)):
if logits[index].shape[0] != target[index].shape[0]:
pass
else:
loss += criterion(logits[index], target[index])
target_token_len.append(target[index].shape[0])
loss = loss / sum(target_token_len)
print('batch_train_loss:', loss)
# run distributed backward pass
dist_autograd.backward(context_id, [loss])
# run distributed optimizer
opt.step(context_id)
the error is: RuntimeError: RRef creation via rpc.remote() timed out, and it is possible that the RRef on the owner node does not exist.
I find the RREF object(namely self.language_model_service_rref) only can be used once.
The first usage occurs in optimizer:
opt = DistributedOptimizer(
optim.SGD,
# [RRef(prompt_embedding_policy.prompt_embeds)],
model.parameter_rrefs(), # get the remote params will use self.language_model_service_rref)
lr=0.05,
)
The second usage is in model(*args) where the error occurs:
logits = model(word_embeds, pad_emnbeds, batch_target_sequence_token)
# model.forward()
def forward(self,
word_embeds,
pad_emnbeds,
batch_target_sequence_token):
# make input embeddings and so on
model_batch, _ = self.prompt_embedding_policy(
word_embeds, pad_emnbeds, batch_target_sequence_token
)
logits = _remote_method(LanguageModelService.forward, # the second usage
self.language_model_service_rref,
model_batch)
Any comment would be appreciated.