TypeError: get_storage_from_record(): incompatible function arguments

Hi, I've run into this issue when I try to load my weights using `netG.load_state_dict(torch.load(PATH))`.
This is the model I use:


class Interpolate(nn.Module):
    """Wrap ``nn.functional.interpolate`` as a module so it can sit inside ``nn.Sequential``.

    Args:
        scale_factor: spatial multiplier forwarded to ``interpolate``.
        mode: interpolation algorithm name forwarded to ``interpolate``
            (e.g. ``'bilinear'``, ``'nearest'``).
        align_corners: forwarded to ``interpolate`` only for the modes that
            accept it (``linear``, ``bilinear``, ``bicubic``, ``trilinear``).
            For any other mode ``None`` is passed instead, because
            ``interpolate`` rejects a non-None ``align_corners`` there.
            Defaults to ``False``, matching the previous hard-coded value.
    """

    # Modes for which F.interpolate accepts a non-None align_corners.
    _ALIGN_CORNERS_MODES = frozenset({'linear', 'bilinear', 'bicubic', 'trilinear'})

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # Decide per call so a later change to ``self.mode`` is still honoured.
        align = self.align_corners if self.mode in self._ALIGN_CORNERS_MODES else None
        return self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=align)
class G(nn.Module):
    """Generator mapping a feature vector ``x`` plus an 18-d attribute vector to an image.

    ``forward`` projects ``attr`` to 128 features with ``fc2`` and concatenates
    it with ``x`` along the last dim. The first conv of ``main`` expects 640
    input channels, so ``x`` must carry 640 - 128 = 512 features. The two
    ``unsqueeze(2)`` calls turn the result into a 640-channel map. ``main``
    then upsamples it 2x eight times and ends in a 3-channel ``Tanh`` output.
    NOTE(review): assumes ``x`` is (batch, 512) and ``attr`` is (batch, 18),
    which makes the map 1x1 and the output 256x256; confirm against callers.
    """

    def __init__(self):
        super().__init__()

        self.fc2 = nn.Linear(18, 128)

        # The layers below are not used by ``forward``. They are kept only so
        # their state_dict keys are unchanged. NOTE(review): consider removing.
        self.relu = nn.ReLU(True)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(4, 4), stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=18, out_channels=256, kernel_size=(4, 4), stride=1, padding=1)

        # Removed: ``self.fc = nn.Linear(512 + 128, 64 * 256 * 256)``.
        # It was never used in ``forward`` but held 640 * 64 * 256 * 256 =
        # 2,684,354,560 weights (~10.7 GB in float32). That element count is
        # larger than 2**31 - 1. Wrapped to a signed 32-bit int it equals
        # -1610612736, the size in the get_storage_from_record TypeError raised
        # by torch.load. Checkpoints saved with it carry extra ``fc.*`` keys;
        # load them with ``strict=False``.

        # Spatial sizes in the comments assume a 1x1 input map (see class docstring).
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=640, out_channels=2048, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 1x1
            nn.BatchNorm2d(2048),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 2x2
            nn.BatchNorm2d(2048),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 4x4
            nn.BatchNorm2d(2048),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=2048, out_channels=1024, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 8x8
            nn.BatchNorm2d(1024),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 16x16
            nn.BatchNorm2d(1024),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 32x32
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 64x64
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 128x128
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            Interpolate(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(3, 3), stride=1, bias=True, padding=1),  # 256x256
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 1x1 projection to 3 output channels, squashed to [-1, 1].
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(1, 1), stride=1, bias=True),
            nn.Tanh(),
        )

    def forward(self, x, attr):
        # Concatenate the projected attributes with x, then add two trailing
        # singleton dims so the features become channels of a spatial map.
        gen_input = torch.cat((self.fc2(attr), x), -1).unsqueeze(2).unsqueeze(2)

        img = self.main(gen_input)
        # NOTE(review): ``img_shape`` is a module-level name defined outside this
        # class; presumably (3, 256, 256). Confirm it exists where G is used.
        img = img.view(img.size(0), *img_shape)
        return img

Traceback (most recent call last):
  File "eval.py", line 28, in <module>
    netG.load_state_dict(torch.load('generator_16.pt'), strict=False)
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 594, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 853, in _load
    result = unpickler.load()
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 845, in persistent_load
    load_tensor(data_type, size, key, _maybe_decode_ascii(location))
  File "C:\Users\laudw\anaconda3\envs\seg\lib\site-packages\torch\serialization.py", line 833, in load_tensor
    storage = zip_file.get_storage_from_record(name, size, dtype).storage()
TypeError: get_storage_from_record(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.PyTorchFileReader, arg0: str, arg1: int, arg2: object) -> at::Tensor

Invoked with: <torch._C.PyTorchFileReader object at 0x00000142AFC786F8>, 'data/2094511113808', -1610612736, torch.float32

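Based on your description, you might be hitting this Python bug on Windows. Could you check whether the posted workaround works for you? Note that the negative size in the error (-1610612736) is exactly 2,684,354,560 wrapped to a signed 32-bit integer. That number is 640 × 64 × 256 × 256, the element count of `self.fc.weight`, which suggests this single, unused tensor is too large for the size argument.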

I tried changing the pickle protocol to 4 in `torch.load` and in the `serialization.py` file, but it still returns the same error.