'MyModel' object has no attribute 'load_state_dict'

Hi. I am trying to load a model to test it with:

import torch

model = MyModel()
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))

But it says:

'MyModel' object has no attribute 'load_state_dict'

What should I do?

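Did you derive MyModel from nn.Module or from another base class? load_state_dict is a method of nn.Module, so MyModel only has it if it inherits from nn.Module (directly or through a parent class).
Make sure to derive it as: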

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module before defining layers
        ...
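As a minimal sketch, assuming a hypothetical MyModel with a single nn.Linear layer (the real architecture isn't shown) and an in-memory buffer standing in for the file at PATH, the following saves and reloads a state_dict on the CPU, then checks that a hypothetical class not derived from nn.Module lacks load_state_dict:

```python
import io

import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical small model; the real MyModel's layers are unknown."""

    def __init__(self):
        super().__init__()  # sets up nn.Module's internal bookkeeping
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


class PlainModel:
    """Hypothetical class that does NOT derive from nn.Module."""

    def __init__(self):
        self.fc = nn.Linear(4, 2)


# Save a state_dict to an in-memory buffer (stands in for the file at PATH).
buffer = io.BytesIO()
torch.save(MyModel().state_dict(), buffer)
buffer.seek(0)

# Loading works because load_state_dict is inherited from nn.Module.
model = MyModel()
result = model.load_state_dict(torch.load(buffer, map_location=torch.device("cpu")))
model.eval()  # switch to evaluation mode before testing

# The plain class has no such method, which is what triggers the AttributeError.
print(hasattr(PlainModel(), "load_state_dict"))  # False
```

load_state_dict returns a named tuple with missing_keys and unexpected_keys; both are empty here because the saved and loaded models share the same architecture.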