Hi, I've run into this issue when I try to load my weights using `netG.load_state_dict(torch.load(PATH), strict=False)`.
This is the model that I use:
class Interpolate(nn.Module):
    """Wrap ``nn.functional.interpolate`` as a module so it can sit in ``nn.Sequential``.

    Args:
        scale_factor: Spatial size multiplier, forwarded to ``interpolate``.
        mode: Interpolation algorithm, e.g. ``'bilinear'`` or ``'nearest'``.
        align_corners: Forwarded to ``interpolate`` only for modes that accept
            it (linear, bilinear, bicubic, trilinear) and ignored for the rest.
            Defaults to ``False``, the value previously hard-coded.
    """

    # F.interpolate raises ValueError if align_corners is not None for any
    # other mode (e.g. 'nearest', 'area'), so it is only passed for these.
    _ALIGN_CORNERS_MODES = frozenset({"linear", "bilinear", "bicubic", "trilinear"})

    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # getattr: whole-module pickles made before this attribute existed
        # skip __init__ on load, so fall back to the old hard-coded value.
        align_corners = getattr(self, "align_corners", False)
        if self.mode not in self._ALIGN_CORNERS_MODES:
            align_corners = None
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
        )
class G(nn.Module):
    """Conditional generator mapping a latent vector plus attributes to an image.

    The 18-feature ``attr`` is embedded to 128 features by ``fc2`` and
    concatenated with ``x`` into a 640-channel 1x1 map. ``main`` upsamples it
    2x eight times (1x1 -> 256x256), ends with a 1x1 conv to 3 channels and
    applies ``Tanh``, so outputs lie in [-1, 1].

    Args:
        legacy_unused_layers: If True, also create ``fc``, ``bn1``, ``conv1``
            and ``conv2``. None of them are used in ``forward``.
            ``fc = nn.Linear(640, 64*256*256)`` alone holds
            640 * 262144 * 16 = 2,684,354,560 weights (~10.7 GB float32). That
            exceeds 2**31 elements and, wrapped to 32 bits, becomes the invalid
            size -1610612736 that some PyTorch builds (observed on Windows)
            fail on in ``torch.load``. Enable it only to load an old
            checkpoint containing these keys with ``strict=True``.
    """

    # (in_channels, out_channels) of the conv after each 2x upsample.
    # Spatial size after each upsample: 2, 4, 8, 16, 32, 64, 128, 256.
    _UP_CHANNELS = (
        (2048, 2048),
        (2048, 2048),
        (2048, 1024),
        (1024, 1024),
        (1024, 512),
        (512, 256),
        (256, 128),
        (128, 64),
    )

    def __init__(self, legacy_unused_layers=False):
        super(G, self).__init__()
        self.fc2 = nn.Linear(18, 128)  # attribute embedding: 18 -> 128
        # Not used in forward; parameter-free, so kept for attribute compatibility.
        self.relu = nn.ReLU(True)
        if legacy_unused_layers:
            self.fc = nn.Linear(512 + 128, 64 * 256 * 256)
            self.bn1 = nn.BatchNorm2d(64)
            self.conv1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(4, 4), stride=1, padding=1)
            self.conv2 = nn.Conv2d(in_channels=18, out_channels=256, kernel_size=(4, 4), stride=1, padding=1)

        # Same layer order as the original hand-written list, so Sequential
        # indices (and state_dict keys like "main.4.weight") are unchanged.
        layers = [
            # Runs at 1x1: with padding=1 only each kernel's centre tap sees data.
            nn.Conv2d(in_channels=640, out_channels=2048, kernel_size=(3, 3), stride=1, bias=True, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(True),
        ]
        for in_ch, out_ch in self._UP_CHANNELS:
            layers += [
                Interpolate(scale_factor=2, mode='bilinear'),
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(3, 3), stride=1, bias=True, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        layers += [
            # 256x256: project 64 feature channels to 3 output channels.
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(1, 1), stride=1, bias=True),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x, attr):
        """Generate images from latent ``x`` and attribute tensor ``attr``.

        Args:
            x: Latent tensor with 512 features on its last axis, so that with
                the 128-dim embedding of ``attr`` it fills ``main``'s 640 input
                channels. Assumed 2-D (batch, 512) -- confirm against callers.
            attr: Attribute tensor with 18 features on its last axis (``fc2`` input).

        Returns:
            ``main``'s (batch, 3, 256, 256) output in [-1, 1], viewed as
            ``(batch, *img_shape)``. ``img_shape`` is a module-level name
            defined outside this class (presumably (3, 256, 256) -- confirm).
        """
        gen_input = torch.cat((self.fc2(attr), x), -1)
        # Assuming 2-D inputs: (batch, 640) -> (batch, 640, 1, 1), a 1x1 "image".
        gen_input = gen_input.unsqueeze(2).unsqueeze(2)
        img = self.main(gen_input)
        return img.view(img.size(0), *img_shape)
Traceback (most recent call last):
  File "eval.py", line 28, in <module>
    netG.load_state_dict(torch.load('generator_16.pt'), strict=False)
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 594, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 853, in _load
    result = unpickler.load()
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 845, in persistent_load
    load_tensor(data_type, size, key, _maybe_decode_ascii(location))
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 833, in load_tensor
    storage = zip_file.get_storage_from_record(name, size, dtype).storage()
TypeError: get_storage_from_record(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.PyTorchFileReader, arg0: str, arg1: int, arg2: object) -> at::Tensor
Invoked with: <torch._C.PyTorchFileReader object at 0x00000142AFC786F8>, 'data/2094511113808', -1610612736, torch.float32