Bug/issue loading pretrained fcos_resnet50_fpn

Hi, I am trying to load a pre-trained FCOS resnet50 object detector to be used with my custom dataset which has 5 classes:

# replace the classifier for 5 classes + background = 6 classes
self.num_classes = 6       

# import pretrained FCOS detector
self.model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True, num_classes=6)

Which gives the following error:

  File "C:\Users\lange\.conda\envs\deepl\lib\site-packages\torchvision\models\detection\fcos.py", line 704, in fcos_resnet50_fpn
    model.load_state_dict(state_dict)
  File "C:\Users\lange\.conda\envs\deepl\lib\site-packages\torch\nn\modules\module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FCOS:
        size mismatch for head.classification_head.cls_logits.weight: copying a param with shape torch.Size([91, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 256, 3, 3]).
        size mismatch for head.classification_head.cls_logits.bias: copying a param with shape torch.Size([91]) from checkpoint, the shape in current model is torch.Size([6]).

Is this a bug? I expected that setting num_classes to anything other than the default of 91 would just work, with the affected layers being adjusted accordingly.

Is there a workaround?

Thanks, Stefan

You won’t be able to use pretrained=True and change the number of classes directly, as the pretrained state_dict expects 91 classes for this model.
On my side I get a proper ValueError instead:

model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True, num_classes=2)
# ValueError: The parameter '2' expected value 91 but got 2 instead.

If you want to use the pretrained parameters and adapt the number of classes, load the original model first and then adapt it so that it returns outputs for 6 classes only.

That’s really strange, I am not getting this ValueError. Which version are you running? I’m running PyTorch 1.11.0 and torchvision 0.12.0.

Could you please explain/point me in the right direction on how to do this?

I wrote this, I am testing it now:

# imports needed by the snippet
import torch.nn as nn
import torchvision

# import pretrained model (default 91 COCO classes)
model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True, num_classes=91)
classification_head = model.head.classification_head

# existing classification output layer and its conv parameters
fout = classification_head.cls_logits

# create a new layer from those parameters
six_class_out = nn.Conv2d(in_channels=fout.in_channels,
                        out_channels=6, # now with 6 classes
                        kernel_size=fout.kernel_size,
                        stride=fout.stride,
                        padding=fout.padding,
                        dilation=fout.dilation,
                        groups=fout.groups,
                        padding_mode=fout.padding_mode,
                        device=fout.weight.device,
                        dtype=fout.weight.dtype)

# replace the classification output layer in the model head with the new one
model.head.classification_head.cls_logits = six_class_out
print(model)

This seems to work. I do wonder if there is a better way though.

I’m using 0.14.0.dev20220531+cu116 so the error message might have improved.

Yes, your approach looks correct.
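
One detail worth keeping in sync with your snippet (at least in the torchvision versions I have checked): the classification head also stores a num_classes attribute that it uses to reshape the raw cls_logits, so it should be updated to match the new layer as well:

# keep the stored class count consistent with the new 6-channel cls_logits layer
model.head.classification_head.num_classes = 6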

For reference, in case someone finds this thread later: there are surprisingly few guides that discuss transfer learning for object detectors, and the implementation differs greatly per model type. This implementation is great: https://github.com/LuisEstebanAcevedoBringas/FCOS_torch/blob/master/FCOS.py#L31

So the solution is the following helper (I call it create_model here, and added the imports; the body is adapted from the linked implementation):

    import math
    import torch
    import torchvision

    def create_model(num_classes):
        # load FCOS pretrained on COCO (91 classes)
        model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True)

        in_features = model.head.classification_head.conv[0].in_channels
        num_anchors = model.head.classification_head.num_anchors

        # the head reshapes its logits using this attribute, so it must match the new layer
        model.head.classification_head.num_classes = num_classes

        out_channels = 256

        # new classification output layer, initialised like the original FCOS head
        cls_logits = torch.nn.Conv2d(out_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(cls_logits.weight, std=0.01)
        torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))

        # replace the pretrained 91-class output layer with the new one
        model.head.classification_head.cls_logits = cls_logits

        print(model)
        return model
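
As a quick sanity check (this uses the create_model helper as sketched above; the dummy input size is arbitrary):

    # build the 6-class model and run a dummy image through it
    model = create_model(num_classes=6)
    model.eval()

    with torch.no_grad():
        # FCOS expects a list of 3xHxW tensors
        predictions = model([torch.rand(3, 512, 512)])

    # each prediction dict contains 'boxes', 'scores' and 'labels'
    print({k: v.shape for k, v in predictions[0].items()})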

I am completely up and running now.