load_state_dict raises 'object is not subscriptable'

Hi everyone,
I have the following model (from transformers):

import torch
from transformers import RobertaForTokenClassification

class ROBERTAClass(torch.nn.Module):
    def __init__(self):
        super(ROBERTAClass, self).__init__()
        # number_of_outputs is defined elsewhere in the script
        self.l1 = RobertaForTokenClassification.from_pretrained('roberta-base', num_labels=number_of_outputs)

    def forward(self, ids, mask, label):
        outputs = self.l1(ids, attention_mask=mask, labels=label)
        # return torch.nn.Softmax(1)(outputs.logits)
        # return torch.sigmoid(outputs.logits)
        return outputs

and the following save and load functions:


def save_ckp(state, checkpoint_path):
    # state should be a dict holding the model/optimizer state_dicts
    torch.save(state, checkpoint_path)

def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    valid_loss_min = checkpoint['valid_loss_min']
    return model, valid_loss_min
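
For context, load_ckp expects the checkpoint to have been saved as a dict with these keys. A minimal sketch of the intended call, where valid_loss stands in for whatever loss value is tracked during training:

checkpoint = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'valid_loss_min': valid_loss,  # placeholder for the tracked validation loss
}
save_ckp(checkpoint, 'checkpoint.pt')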
   
   

The issue is that if I save the model and then load it, I get the following error on the call to model.load_state_dict(checkpoint['state_dict']): TypeError: 'ROBERTAClass' object is not subscriptable.
Do you know the reason for this?
Thanks!

Based on the error, I guess you are not saving the state_dict (which would be the recommended way) but the model object directly, while assuming that checkpoint is a dict containing several internal state_dicts.
Here is a small example showing the issue:

import torch
import torch.nn as nn

model = nn.Linear(1, 1)

# save the model object directly
torch.save(model, "tmp.pt")
# torch.load now returns the nn.Linear module itself, not a dict
checkpoint = torch.load("tmp.pt")
checkpoint['state_dict']
# TypeError: 'Linear' object is not subscriptable
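
Saving a checkpoint dict instead makes the lookup work the way your load_ckp expects. A minimal sketch (the 0.5 is just a placeholder for a tracked loss value):

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# save a dict of state_dicts instead of the model object
checkpoint = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'valid_loss_min': 0.5,  # placeholder
}
torch.save(checkpoint, "tmp.pt")

checkpoint = torch.load("tmp.pt")
model.load_state_dict(checkpoint['state_dict'])      # works now
optimizer.load_state_dict(checkpoint['optimizer'])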

Thanks so much, that’s it. Such a silly mistake :sweat_smile: