Hi all,
I have a sequence tagger built on top of a Hugging Face BERT model.
import torch
import torch.nn as nn
from transformers import BertModel

class SequenceTagger(nn.Module):
    def __init__(self, bert_dim, n_labels, bert_model_dir, dp=0.5):
        super(SequenceTagger, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_dir)
        self.out = nn.Linear(bert_dim, n_labels)      # token-level classifier head
        self.dropout = nn.Dropout(dp)
        self.log_softmax = nn.LogSoftmax(dim=2)       # over the label dimension
        self.nll_loss = nn.NLLLoss(ignore_index=-1)   # -1 marks ignored positions
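The forward pass (omitted above for brevity) essentially runs BERT, applies dropout and the linear head, then takes log-probabilities. A simplified sketch, with illustrative argument names:

    def forward(self, input_ids, attention_mask, labels=None):
        # outputs[0]: per-token hidden states of shape (batch, seq_len, bert_dim)
        sequence_output = self.bert(input_ids, attention_mask=attention_mask)[0]
        log_probs = self.log_softmax(self.out(self.dropout(sequence_output)))
        if labels is not None:
            # NLLLoss expects (batch, n_labels, seq_len), hence the permute
            return log_probs, self.nll_loss(log_probs.permute(0, 2, 1), labels)
        return log_probs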
After training the model, I save it as follows:
model = SequenceTagger(.....)
model = train(model)
torch.save({
    "epoch": i,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "best_avg_recall": current_avg_recall
}, model_full_path)  # full path to the checkpoint file under args.model_save_dir
And then I load it as follows:
model = SequenceTagger(bert_dim, len(self._l2ix), model_dir)
model.load_state_dict(torch.load(model_full_path)["model_state_dict"])
model.to(self._device)
model.eval()
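As an aside, if the device differed between save and load time, I believe torch.load would also need an explicit map_location, something like:

state = torch.load(model_full_path, map_location=self._device)
model.load_state_dict(state["model_state_dict"])

In my case, though, only the directory of the file changes, not the device.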
This works smoothly as long as I do not move the saved checkpoint. However, if I copy the checkpoint somewhere else and pass the copied path to the load function, I get the following error:
RuntimeError: Error(s) in loading state_dict for SequenceTagger:
Unexpected key(s) in state_dict:
When I check the official docs, this appears to be the recommended way to serialize a model. What could be the problem?
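To give more detail, here is a quick check that lists exactly which keys are unexpected and which are missing (reusing the names from the snippets above):

ckpt_keys = set(torch.load(model_full_path)["model_state_dict"].keys())
model_keys = set(model.state_dict().keys())
print("unexpected:", sorted(ckpt_keys - model_keys))
print("missing:", sorted(model_keys - ckpt_keys))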