I trained a model on 8 GPUs and successfully saved the model file using:
# NOTE: if `model` is the nn.DataParallel wrapper at this point, every saved
# key carries a 'module.' prefix; saving `model.module.state_dict()` instead
# would store the bare model's parameter names.
torch.save(model.state_dict(), "a.model")
Then I used "a.model" in two ways.
First, I used "a.model" to predict some samples. That code runs correctly with no errors. It looks like this:
from collections import OrderedDict

# Build the bare model and load the checkpoint into it *before* wrapping it in
# nn.DataParallel, so the wrapper is applied exactly once.
model = LstmAttentionModel()
# map_location='cpu' avoids restoring every tensor onto the GPU it was saved
# from; the model is moved to the GPU with .cuda() after loading.
state_dict = torch.load('a.model', map_location='cpu')

# Normalize keys to the bare model's naming so checkpoints saved either from
# the bare model or from a DataParallel wrapper both load.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Presumably produced when only a `features` submodule was wrapped in
    # DataParallel during training; map it back to the plain name.
    k = k.replace('features.module.', 'features.')
    # Strip the prefix a top-level DataParallel wrapper adds. An anchored
    # prefix check is used instead of `'module' not in k`, which also matches
    # any parameter whose name merely contains the substring "module".
    if k.startswith('module.'):
        k = k[len('module.'):]
    new_state_dict[k] = v
model.load_state_dict(new_state_dict)

model.cuda()
model = nn.DataParallel(model)
model.eval()
# Inference only: don't build the autograd graph (saves GPU memory).
with torch.no_grad():
    result = model()
The model loads without errors, and `result = model()` returns the expected result.
Next, I used "a.model" for a second purpose: I want to continue training it rather than start training from scratch, similar to a fine-tuning step. The code is very similar to the code in my first use case:
from collections import OrderedDict

# Build the bare model and load the checkpoint into it *before* wrapping it in
# nn.DataParallel, so the wrapper is applied exactly once.
model = LstmAttentionModel()
# map_location='cpu' avoids restoring every tensor onto the GPU it was saved
# from; the model is moved to the GPU with .cuda() after loading.
state_dict = torch.load('a.model', map_location='cpu')

# Normalize keys to the bare model's naming so checkpoints saved either from
# the bare model or from a DataParallel wrapper both load.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Presumably produced when only a `features` submodule was wrapped in
    # DataParallel during training; map it back to the plain name.
    k = k.replace('features.module.', 'features.')
    # Anchored prefix strip (a substring test on 'module' would also match
    # unrelated parameter names).
    if k.startswith('module.'):
        k = k[len('module.'):]
    new_state_dict[k] = v
model.load_state_dict(new_state_dict)

model.cuda()
# Wrap exactly once. NOTE(review): the reported traceback shows
# DataParallel.forward running *inside* a DataParallel replica ("replica 0 on
# device 0" -> data_parallel.py forward -> "replica 1 on device 1"), i.e. the
# model reached train_model() wrapped twice. That nested DataParallel is what
# raises "Input, output and indices must be on the current device". Make sure
# no other code in this script (or inside train_model) wraps `model` again.
model = nn.DataParallel(model)

train_dataset = CustomDataset("../data/all_sample/202109", batch_size, use_cpus, train_file_num)
val_dataset = CustomDataset("../data/eval", batch_size, use_cpus, eval_file_num)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=use_cpus,
                          drop_last=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=use_cpus,
                        drop_last=True, pin_memory=True)

best_acc = 0
for epoch in range(epochs):
    train_loss, train_acc = train_model(model, train_loader, epoch)
    print("train end,begin eval")
    val_loss, val_acc = eval_model(model, val_loader)
    print("eval end")
    if val_acc > best_acc:
        # `.3f` (precision 3) for Val. Loss, matching Train Loss; the original
        # `:3f` only set a minimum width of 3 and printed six decimals.
        print(
            f'Epoch: {epoch + 1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, '
            f'Val. Loss: {val_loss:.3f}, Val. Acc: {val_acc:.2f}%')
        best_acc = val_acc
        timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
        # Save the unwrapped model's weights so the checkpoint has no
        # 'module.' prefix and loads directly into a bare LstmAttentionModel;
        # the loading code above also accepts this format.
        torch.save(model.module.state_dict(),
                   './model/best-model-finalacc' + str(best_acc) + '-time' + timestamp + ".param")
But I get the following error:
Traceback (most recent call last):
File "main_sogou_continue_train.py", line 196, in <module>
train_loss, train_acc = train_model(model, train_loader, epoch)
File "main_sogou_continue_train.py", line 79, in train_model
query_emb, doc_emb, all_score, pos_score = model(doc, query, dominid,siteid,query_len,doc_len)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/search/odin/ningshiqi/clickRecommend/src_continue_train/model_sogou.py", line 139, in forward
query_embedding = self.word_embeddings(query)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 126, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/functional.py", line 1852, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Input, output and indices must be on the current device
In both cases the model-loading code is exactly the same, but only the second case fails, with "RuntimeError: Input, output and indices must be on the current device". Why does identical loading code give two different results?