Yes, I used nn.DataParallel. I didn't quite follow your second suggestion: load the weights file, create a new OrderedDict without the `module.` prefix, and load that back into the model. Could you provide an example?
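Do you mean something along these lines? This is just a minimal sketch of my understanding; `checkpoint.pt` and the already-constructed `model` are placeholders for my own setup:

```python
from collections import OrderedDict

import torch

# Load weights saved from an nn.DataParallel model; the keys look like
# 'module.layer.weight'.
state_dict = torch.load('checkpoint.pt', map_location='cpu')

# Rebuild the dict with the 'module.' prefix stripped from every key.
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v

# Load the cleaned-up weights into a plain (non-DataParallel) model.
model.load_state_dict(new_state_dict)
```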
Or are you suggesting something like this, where the `module` wrapper is unwrapped at save time instead? (Example taken from https://github.com/OpenNMT/OpenNMT-py/blob/master/train.py)
```python
# Unwrap the nn.DataParallel container when training on multiple GPUs so
# the saved keys carry no 'module.' prefix.
model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
# The generator's parameters are filtered out here and saved under their
# own checkpoint key below.
model_state_dict = {k: v for k, v in model_state_dict.items() if 'generator' not in k}
generator_state_dict = model.generator.module.state_dict() if len(opt.gpus) > 1 else model.generator.state_dict()

# (4) drop a checkpoint
checkpoint = {
    'model': model_state_dict,
    'generator': generator_state_dict,
    'dicts': dataset['dicts'],
    'opt': opt,
    'epoch': epoch,
    'optim': optim,
}
torch.save(checkpoint,
           '%s_acc_%.2f_ppl_%.2f_e%d.pt'
           % (opt.save_model, 100 * valid_acc, valid_ppl, epoch))
```
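If I read that right, a checkpoint saved this way can be loaded back without any prefix stripping, since `model.module.state_dict()` already emits plain keys. Here is a sketch of what I would try; the path and the pre-built `model` are placeholders, and `strict=False` is my own assumption to account for the generator weights being stored under a separate key:

```python
checkpoint = torch.load('path/to/checkpoint.pt', map_location='cpu')
# Keys carry no 'module.' prefix, so they load straight into a plain model.
# strict=False because the generator's parameters were filtered out of
# checkpoint['model'] and stored separately above.
model.load_state_dict(checkpoint['model'], strict=False)
model.generator.load_state_dict(checkpoint['generator'])
```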
May I also ask one question about the snippet above: what is `generator` here?