Hi everyone,
I have the following model (from transformers):
class ROBERTAClass(torch.nn.Module):
def __init__(self):
super(ROBERTAClass, self).__init__()
self.l1 = RobertaForTokenClassification.from_pretrained('roberta-base', num_labels= number_of_outputs)
def forward(self, ids, mask, label):
outputs= self.l1(ids, attention_mask = mask, labels=label)
#return torch.nn.Softmax(1)( outputs.logits)
# return torch.sigmoid(outputs.logits)
return outputs
and the following save and load functions
def save_ckp(state, checkpoint_path):
f_path = checkpoint_path
torch.save(state, f_path)
def load_ckp(checkpoint_fpath, model, optimizer):
checkpoint = torch.load(checkpoint_fpath)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
valid_loss_min = checkpoint['valid_loss_min']
return model
The issue is that if I save the model and then load it I get the following error on the call of model.load_state_dict(checkpoint[‘state_dict’]): ROBERTAClass object is not is not subscriptable.
Do you know the reason for this?
Thanks!