I changed the code, and it works. Thank you.
class WrappedModel(nn.Module):
    """Thin container that nests a network under the attribute name ``module``.

    Because the wrapped network is registered as ``self.module``, every key in
    this wrapper's ``state_dict`` carries a ``"module."`` prefix. That lets a
    checkpoint whose keys already have that prefix be loaded with
    ``load_state_dict`` without renaming any keys. Such a checkpoint is
    typically one saved from an ``nn.DataParallel``-wrapped model; confirm
    this for your own checkpoint.

    Args:
        module: The network to wrap. Its forward pass is called unchanged.
    """

    def __init__(self, module):
        super().__init__()
        # The attribute name must be exactly "module": nn.Module derives the
        # state_dict key prefix ("module.") from it.
        self.module = module

    def forward(self, x):
        """Delegate the forward pass to the wrapped network."""
        return self.module(x)
# Look up the constructor named by args.model in `models` and build it with
# args, then nest it under `.module` so its parameter names gain the
# "module." prefix that the checkpoint's keys are expected to carry.
model = getattr(models, args.model)(args)
model = WrappedModel(model)
# map_location="cpu" lets a checkpoint saved with GPU tensors load on a
# CPU-only host; load_state_dict then copies the values onto the model's
# own parameters, wherever those live.
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoint files from a trusted source.
state_dict = torch.load(modelname, map_location="cpu")['state_dict']
model.load_state_dict(state_dict)