Loading the model for evaluation works, but loading it to continue training raises an error

I trained a model on 8 GPUs and successfully saved the model file using

# Save only the parameters/buffers. If `model` is wrapped in nn.DataParallel,
# every key carries a 'module.' prefix.
torch.save(model.state_dict(), 'a.model')

Then I used “a.model” in two ways.
First, I used “a.model” to predict on some samples. The code runs correctly with no errors. It looks like this:

# Inference path: build the model, move it to the GPU, wrap it once in
# nn.DataParallel, then load the saved weights.
model = LstmAttentionModel( )
model.cuda()
# The wrapper stores the network as `.module`, so its state_dict keys gain a
# 'module.' prefix.
model = nn.DataParallel(model)
state_dict = torch.load('a.model')
from collections import OrderedDict
new_state_dict = OrderedDict()
# Remap checkpoint keys to the wrapped model's naming (this fixed the key
# errors on load):
#   - keys that do not contain 'module' get a 'module.' prefix;
#   - keys that already contain 'module' are kept, except that
#     'features.module.' is rewritten to 'module.features.'.
# NOTE(review): `'module' not in k` is a substring test, so a key that merely
# contains "module" elsewhere in its name is left unprefixed;
# `k.startswith('module.')` would be the stricter check.
for k, v in state_dict.items():
	if 'module' not in k:
		k = 'module.'+k
	else:
		k = k.replace('features.module.', 'module.features.')
	new_state_dict[k]=v
model.load_state_dict(new_state_dict)
# Put every submodule into evaluation mode (affects layers such as Dropout).
model.eval()
# NOTE(review): during training forward receives six inputs (see the traceback
# below); the arguments appear to be omitted here for brevity.
result = model()

The model loads without errors, and “result = model()” runs correctly.

Then I used ‘a.model’ a second way: I want to continue training from ‘a.model’ rather than training from scratch, similar to a fine-tuning step. The code is very similar to the code in my first use case.

# Continue-training path: rebuild the model, move it to the GPU, wrap it in
# nn.DataParallel exactly once, and load the previously saved weights.
# NOTE: `model` is wrapped below. It must not be wrapped in nn.DataParallel
# again anywhere downstream (e.g. inside train_model or eval_model, which are
# not shown here). A second wrap nests DataParallel inside DataParallel, as
# the nested data_parallel.py forward frames in the traceback show.
model = LstmAttentionModel()
model.cuda()
model = nn.DataParallel(model)
state_dict = torch.load('a.model')
from collections import OrderedDict
new_state_dict = OrderedDict()
# Remap checkpoint keys to the wrapped model's naming: keys without 'module'
# get a 'module.' prefix; keys that already contain 'module' only have
# 'features.module.' rewritten to 'module.features.'.
# NOTE(review): `'module' not in k` is a substring test;
# `k.startswith('module.')` would be the stricter check.
for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.'+k
    else:
        k = k.replace('features.module.', 'module.features.')
    new_state_dict[k]=v
model.load_state_dict(new_state_dict)

train_dataset = CustomDataset("../data/all_sample/202109", batch_size,use_cpus,train_file_num)
val_dataset = CustomDataset("../data/eval", batch_size,use_cpus,eval_file_num)
train_loader = DataLoader(train_dataset, batch_size=batch_size,  num_workers=use_cpus,drop_last=True,pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=use_cpus, drop_last=True,pin_memory=True)
best_acc = 0
for epoch in range(epochs):
    train_loss, train_acc = train_model(model, train_loader, epoch)
    print("train end,begin eval")
    val_loss, val_acc = eval_model(model, val_loader)
    print("eval end")
    # Log and checkpoint only when validation accuracy improves.
    if val_acc >best_acc:
        # Both losses use `.3f` (three decimal places).
        print(
        f'Epoch: {epoch + 1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Val. Loss: {val_loss:.3f}, Val. Acc: {val_acc:.2f}%')
        best_acc = val_acc
        # Saved from the wrapped model, so keys carry the 'module.' prefix.
        # NOTE(review): ':' in the timestamp is not a valid filename
        # character on Windows.
        torch.save(model.state_dict(), './model/best-model-finalacc'+str(best_acc)+'-time'+str(time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()))+".param")

But I get the following error:

Traceback (most recent call last):
  File "main_sogou_continue_train.py", line 196, in <module>
    train_loss, train_acc = train_model(model, train_loader, epoch)
  File "main_sogou_continue_train.py", line 79, in train_model
    query_emb, doc_emb, all_score, pos_score = model(doc, query, dominid,siteid,query_len,doc_len)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/search/odin/ningshiqi/clickRecommend/src_continue_train/model_sogou.py", line 139, in forward
    query_embedding = self.word_embeddings(query)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 126, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/search/odin/bateer/tools/anaconda2/envs/shiqi/lib/python3.6/site-packages/torch/nn/functional.py", line 1852, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Input, output and indices must be on the current device

In both cases the model-loading code is exactly the same, but in the second case I get the error “RuntimeError: Input, output and indices must be on the current device”. Why does the same loading code give two different results?

Do you see the same error if you load the state_dict into the model before wrapping it in nn.DataParallel?

Thanks for your reply! Following your suggestion, I swapped the order of the nn.DataParallel wrap and load_state_dict.

from

# Original order: move to GPU, wrap in nn.DataParallel, then load weights whose
# keys are remapped to the wrapper's 'module.'-prefixed naming.
model = LstmAttentionModel()
model.cuda()
model = nn.DataParallel(model)
state_dict = torch.load('a.model')
from collections import OrderedDict
new_state_dict = OrderedDict()
# Keys without 'module' get a 'module.' prefix; keys that already contain it
# only have 'features.module.' rewritten to 'module.features.'.
for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.'+k
    else:
        k = k.replace('features.module.', 'module.features.')
    new_state_dict[k]=v
model.load_state_dict(new_state_dict)

to

# Reordered: load the weights into the bare (unwrapped) model, then wrap it once.
model = LstmAttentionModel()
model.cuda()
state_dict = torch.load('a.model')
# No key remapping here. With the default strict=True, this only succeeds if
# the checkpoint keys have no 'module.' prefix.
model.load_state_dict(state_dict)
model = nn.DataParallel(model)

But I get exactly the same error.

I also tried:

# Variant: load the weights before moving to the GPU (assuming the constructor
# creates its parameters on the CPU), then move to GPU and wrap once.
model = LstmAttentionModel()

# Without map_location, torch.load restores tensors to the device they were
# saved from; load_state_dict copies them into the model's own parameters.
state_dict = torch.load('a.model')
model.load_state_dict(state_dict)
model.cuda()
model = nn.DataParallel(model)

Still the same error.

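I found the answer: I was calling model = nn.DataParallel(model) twice… Wrapping an already-wrapped model nests one DataParallel inside another. This shows up in the traceback above as two nested data_parallel.py forward frames, with replica 1 failing inside replica 0. As a result, the inner replicas end up with tensors on mismatched devices. Wrapping the model exactly once fixes it.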