Loading weights from DataParallel models

I changed the code as follows, and it works. Thank you.

class WrappedModel(nn.Module):
    """Hold a model under the attribute name ``module``.

    Because the wrapped model is registered as ``self.module``, every key in
    ``state_dict()`` gains a ``module.`` prefix. That matches the key layout of
    a checkpoint saved from a model wrapped in ``nn.DataParallel``, so such a
    checkpoint can be loaded without renaming its keys.

    Args:
        module: The model to wrap.
    """

    def __init__(self, module):
        super().__init__()
        # The attribute name is what produces the "module." key prefix.
        self.module = module

    def forward(self, x, *args, **kwargs):
        """Call the wrapped model with all given arguments.

        ``x`` stays the first named parameter for backward compatibility.
        Any extra positional or keyword arguments are passed through, so
        models whose ``forward`` takes several inputs can be wrapped too.
        """
        return self.module(x, *args, **kwargs)

# Build the bare architecture named by args.model, then wrap it so its
# parameter names carry the "module." prefix expected by a checkpoint that
# was saved from an nn.DataParallel-wrapped model.
model = getattr(models, args.model)(args)
model = WrappedModel(model)
# map_location="cpu": without it, tensors saved from a GPU are restored onto
# that CUDA device, which fails on a CPU-only machine. load_state_dict copies
# the values into the model's own parameters, so their device is unaffected.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# Assumes the checkpoint is a dict holding the weights under 'state_dict'.
state_dict = torch.load(modelname, map_location="cpu")['state_dict']
model.load_state_dict(state_dict)
4 Likes