import torch
import torch.nn as nn
class SomeModel(nn.Module):
def __init__(self, gpu_ids=[]):
super(SomeModel, self).__init__()
self.gpu_ids = gpu_ids
mean = torch.autograd.Variable(torch.Tensor([0.5, 0.5, 0.5]).view(1,3,1,1)).cuda(gpu_ids[0])
std = torch.autograd.Variable(torch.Tensor([0.5, 0.5, 0.5]).view(1,3,1,1)).cuda(gpu_ids[0])
self.register_buffer('mean', mean)
self.register_buffer('std', std)
def forward(self, input):
input = (input - self.mean) / self.std
if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
else:
return self.net(input)
When I run cmd in python interpreter, strange error occurs:
>>> from a import SomeModel
>>> model = SomeModel(gpu_ids=[1,2])
>>> m = model.state_dict()
>>> model.load_state_dict(m)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/data/image_server/extra/user/pangwong/anaconda2/envs/pytorch0.3.1/lib/python2.7/site-packages/torch/nn/modules/module.py", line 519, in load_state_dict
.format(name, own_state[name].size(), param.size()))
RuntimeError: While copying the parameter named mean, whose dimensions in the model are torch.Size([1, 3, 1, 1]) and whose dimensions in the checkpoint are torch.Size([1, 3, 1, 1]).
Environment:
- python2.7
- pytorch0.3.1