My guess is that you saved with `torch.save(model, PATH)`, which pickles the entire model object (architecture, weights, and any other attributes), instead of `torch.save(model.state_dict(), PATH)`, which saves only the state dict, i.e. the weights.
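If you need to stick with the first approach, you must load it with `model = torch.load(PATH)`. The model's class definition has to be importable at load time, and on recent PyTorch releases, where `torch.load` defaults to `weights_only=True`, you will likely also need to pass `weights_only=False`.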
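To make the difference concrete, here is a minimal sketch of both approaches side by side. The `make_model` helper, the layer sizes, and the file names are made up for illustration and are not from your code. It also assumes PyTorch 1.13 or newer, since it passes the `weights_only` argument to `torch.load`.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for your architecture.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = make_model()
tmp_dir = tempfile.mkdtemp()
full_path = os.path.join(tmp_dir, "model_full.pt")        # illustrative file name
weights_path = os.path.join(tmp_dir, "model_weights.pt")  # illustrative file name

# Approach 1: pickle the whole model object.
torch.save(model, full_path)
# torch.load returns the model itself. weights_only=False permits full
# unpickling, so only use it on files you trust.
loaded_full = torch.load(full_path, weights_only=False)

# Approach 2: save only the state dict (the weights).
torch.save(model.state_dict(), weights_path)
# Rebuild the architecture first, then copy the saved weights into it.
loaded_weights = make_model()
loaded_weights.load_state_dict(torch.load(weights_path))
loaded_weights.eval()
```

If you mix the two up, for example by calling `load_state_dict` on the result of `torch.load` for a file written by approach 1, you are handing it a model object where it expects a dict of tensors, which is a common cause of this kind of loading error.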