RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "ca.fc1.weight", "ca.fc2.weight", "sa.conv1.weight", "ca1.fc1.weight", "ca1.fc2.weight", "sa1.conv1.weight"

I modified the ResNet-50 network by adding CBAM-style attention modules (channel attention and spatial attention) to it.

These are the two classes:

class ChannelAttention(nn.Module):
    """Channel-attention gate (the channel half of CBAM).

    The input is squeezed over its spatial dimensions with both adaptive
    average pooling and adaptive max pooling. Each pooled descriptor goes
    through the same two-layer 1x1-conv bottleneck (``fc1`` -> ReLU ->
    ``fc2``). The two results are summed and passed through a sigmoid, giving
    one weight per channel. Multiply the result with the input to re-weight
    its channels.

    Args:
        in_planes: Number of input channels.
        ratio: Channel-reduction factor of the bottleneck. The hidden layer has
            ``in_planes // ratio`` channels (at least 1).

    Returns (forward):
        A tensor of sigmoid weights whose spatial size is 1x1 and whose channel
        count equals ``in_planes``.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Honour ``ratio`` (it used to be ignored in favour of a literal 16).
        # Clamp to 1 so that small ``in_planes`` does not produce a
        # zero-channel conv.
        hidden = max(1, in_planes // ratio)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same bottleneck (fc1/fc2) is shared by both pooled descriptors.
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """Spatial-attention gate (the spatial half of CBAM).

    The input is reduced over the channel dimension in two ways, by mean and by
    max. The two maps are stacked into a 2-channel map, convolved down to one
    channel, and passed through a sigmoid. The result is one weight per spatial
    position. Multiply the result with the input to re-weight its positions.

    Args:
        kernel_size: Size of the convolution kernel; must be 3 or 7. Padding is
            chosen so that the spatial size is preserved.

    Raises:
        ValueError: If ``kernel_size`` is not 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        # Raise instead of ``assert`` so the check survives ``python -O``.
        if kernel_size not in (3, 7):
            raise ValueError('kernel size must be 3 or 7')
        padding = 3 if kernel_size == 7 else 1  # keeps H and W unchanged

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        # Channel 0 of the stacked map is the mean; channel 1 is the max.
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

But when I run `model = models.resnet50(pretrained=True)`, I get the error below. I have tried several solutions but still don't know how to fix it.

> RuntimeError                              Traceback (most recent call last)
> <ipython-input-32-1d0a3ace5f7c> in <module>
> ----> 1 model = models.resnet50(pretrained=True)
>       2 
>       3 
>       4 #freeze layers
>       5 # for param in model.parameters():
> D:\pythonana\envs\pytorch\lib\site-packages\torchvision\models\resnet.py in resnet50(pretrained, progress, **kwargs)
>     319         pretrained (bool): If True, returns a model pre-trained on ImageNet
>     320         progress (bool): If True, displays a progress bar of the download to stderr
> --> 321     """
>     322     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
>     323                    **kwargs)
> D:\pythonana\envs\pytorch\lib\site-packages\torchvision\models\resnet.py in _resnet(arch, block, layers, pretrained, progress, **kwargs)
>     282         state_dict = load_state_dict_from_url(model_urls[arch],
>     283                                               progress=progress)
> --> 284         #model.load_state_dict(torch.load(PATH),strict=False)
>     285 
>     286         model.load_state_dict(state_dict)
> D:\pythonana\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
>     828         if len(error_msgs) > 0:
>     829             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
> --> 830                                self.__class__.__name__, "\n\t".join(error_msgs)))
>     831         return _IncompatibleKeys(missing_keys, unexpected_keys)
>     832 
> RuntimeError: Error(s) in loading state_dict for ResNet:
> 	Missing key(s) in state_dict: "ca.fc1.weight", "ca.fc2.weight", "sa.conv1.weight", "ca1.fc1.weight", "ca1.fc2.weight", "sa1.conv1.weight".


You've modified ResNet50 and are trying to load weights from the pretrained vanilla ResNet50. That causes the error, because the pretrained network has no weights for your new ChannelAttention and SpatialAttention layers. I guess you have made changes to the torchvision source code. Don't change the source code. Instead, implement a custom network, then copy the weights into your custom model by iterating over the state_dict of the pretrained torchvision ResNet50.

But I have to add the attention-mechanism code to change the network, and I still want the pretrained weights. What can I do to solve this?

Check torchvision's implementation of ResNet50 and implement a custom network along these lines:

class ModifiedResNet50(nn.Module):
    """Skeleton of a custom ResNet-50 with extra (e.g. attention) layers.

    Re-create the torchvision ResNet-50 layers here under the same attribute
    names (``conv1``, ``bn1``, ``layer1`` ... ``layer4``, ``fc``) so that the
    pretrained ``state_dict`` keys still match. Then add the new modules
    alongside them.
    """

    def __init__(self):
        super(ModifiedResNet50, self).__init__()
        # TODO: build the ResNet-50 layers and the extra attention modules.

    def forward(self, x):
        # TODO: run the layers in order and return the logits.
        raise NotImplementedError("ModifiedResNet50.forward is a skeleton")

Then load the weights from the pretrained ResNet50 as shown below:

vanilla_resnet = models.resnet50(pretrained=True)
modified_resnet = ModifiedResNet50()

# ``state_dict`` is a method, so it cannot be indexed as ``state_dict[k]``.
# Each call also returns a new dict, so assigning to its keys never reaches
# the model. Instead, build the merged dict explicitly and pass it to
# ``load_state_dict``. Only tensors whose name AND shape exist in both networks
# are copied; the extra attention modules keep their fresh initialisation.
modified_state = modified_resnet.state_dict()
pretrained_state = {
    name: tensor
    for name, tensor in vanilla_resnet.state_dict().items()
    if name in modified_state and tensor.shape == modified_state[name].shape
}
modified_state.update(pretrained_state)
modified_resnet.load_state_dict(modified_state)

Well, I added the class ModifiedResNet50 to torchvision's ResNet source file (`torchvision/models/resnet.py`):

class ModifiedResNet50(nn.Module):
    """ResNet in torchvision's layout, with two pairs of CBAM attention gates.

    This mirrors torchvision's ``ResNet`` and adds four modules:

    * ``ca`` / ``sa``: channel and spatial attention applied after the stem
      (conv1 -> bn1 -> relu) and before the max-pool.
    * ``ca1`` / ``sa1``: channel and spatial attention applied after
      ``layer4`` and before global average pooling.

    A pretrained torchvision checkpoint has no entries for these four modules.
    Load it with ``load_state_dict(..., strict=False)``, or copy only the
    matching keys; the attention modules then keep their fresh initialisation.

    NOTE(review): this class expects ``BasicBlock`` and ``Bottleneck`` (from
    ``torchvision.models.resnet``) plus ``ChannelAttention`` and
    ``SpatialAttention`` to be in scope in the defining script.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``); must
            expose an ``expansion`` attribute.
        layers: Number of blocks in each of the four stages, e.g.
            ``[3, 4, 6, 3]`` for ResNet-50.
        num_classes: Output size of the final linear layer.
        zero_init_residual: Zero-initialise the last BN of every residual
            branch.
        groups: Group count passed to each block.
        width_per_group: Base width passed to each block.
        replace_stride_with_dilation: Three booleans, one per stage 2-4,
            telling whether to replace that stage's stride with dilation.
        norm_layer: Normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have 3 items.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ModifiedResNet50, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # 1st attention mechanism (operates on the 64-channel stem output)
        self.ca = ChannelAttention(self.inplanes)
        self.sa = SpatialAttention()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # 2nd attention mechanism. After _make_layer, self.inplanes equals
        # 512 * block.expansion, i.e. layer4's output channels.
        self.ca1 = ChannelAttention(self.inplanes)
        self.sa1 = SpatialAttention()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one ResNet stage of ``blocks`` residual blocks.

        This also advances ``self.inplanes`` (and ``self.dilation`` when
        ``dilate`` is set) so that the next stage is built with the right
        input width.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection on the shortcut. This is torchvision's conv1x1,
            # inlined so that helper need not be imported.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # Attention gates return weights; multiplying re-weights the features.
        x = self.ca(x) * x
        x = self.sa(x) * x
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.ca1(x) * x
        x = self.sa1(x) * x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

But why does the problem still occur?


It seems you are still editing the torchvision source code. Don't add ModifiedResNet50 to torchvision's `resnet.py`; write/implement your custom class in your own script or notebook cell, and import the Bottleneck (and BasicBlock) layers from torchvision if needed. Also change `super(ResNet, self).__init__()` to `super(ModifiedResNet50, self).__init__()`. I can't write the whole code here. I can only suggest a direction (so a direct copy-paste won't work), and you have to implement the rest. Hope you understand.

Okay, thank you very much! I would try it.

Hello, I have implemented the modified network (resnet_cbam) successfully. I think the loop should be:
# ``state_dict`` is a method (so ``state_dict[k]`` raises TypeError), and each
# call returns a new dict, so assigning to its keys never changes the model.
# Merge the matching pretrained tensors into one dict, then load it.
modified_state = modified_resnet.state_dict()
for k, v in vanilla_resnet.state_dict().items():
    # Copy only keys that exist in both networks with the same shape; the
    # attention modules have no pretrained counterpart.
    if k in modified_state and v.shape == modified_state[k].shape:
        modified_state[k] = v
modified_resnet.load_state_dict(modified_state)

Can you please comment on this post.