just do this:
model = torch.load(train_model) … net.load_state_dict(model['state_dict'])
it works for me!
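Here is a minimal, self-contained sketch of the same save/load round trip. It assumes the checkpoint file (`train_model` above, presumably a path) was written with `torch.save` as a dict that keeps the weights under the key 'state_dict'. The names `Net`, `save_checkpoint` and `load_checkpoint` are hypothetical and used only for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical toy model; substitute your own architecture.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def save_checkpoint(net, path):
    # Assumption: the checkpoint is a dict with the weights under 'state_dict'.
    torch.save({"state_dict": net.state_dict()}, path)


def load_checkpoint(net, path):
    # torch.load returns the saved dict; map_location="cpu" lets a
    # checkpoint saved on a GPU machine load on a CPU-only one.
    checkpoint = torch.load(path, map_location="cpu")
    # Copy the saved tensors into the freshly constructed model.
    net.load_state_dict(checkpoint["state_dict"])
    return net


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "checkpoint.pth")
        save_checkpoint(Net(), path)
        restored = load_checkpoint(Net(), path)
        restored.eval()  # switch to inference mode before evaluating
```

If the file was instead saved directly with `torch.save(net.state_dict(), path)`, there is no 'state_dict' key to index; pass the loaded object straight to `net.load_state_dict`.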