def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
                            num_classes=91, pretrained_backbone=True,
                            trainable_backbone_layers=3, **kwargs):
    """Build a Faster R-CNN model on a ResNet-50 FPN backbone.

    When ``pretrained`` is set, the checkpoint referenced by
    ``model_urls['fasterrcnn_resnet50_fpn_coco']`` is downloaded and merged
    into the freshly built model. Only checkpoint entries whose name exists
    in the model AND whose tensor shape matches are copied; every other
    parameter keeps the value it had after construction. This lets the
    checkpoint be reused when parts of the model (for example a head built
    for a different ``num_classes``) no longer match the checkpoint's shapes.

    Args:
        pretrained: If true, load the matching entries of the detection
            checkpoint into the model.
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: Forwarded to ``FasterRCNN``.
        pretrained_backbone: Forwarded to ``resnet_fpn_backbone``. Forced to
            ``False`` when ``pretrained`` is set, since the full checkpoint
            is loaded afterwards anyway.
        trainable_backbone_layers: Value in ``[0, 5]`` forwarded to
            ``resnet_fpn_backbone`` as ``trainable_layers``. Forced to ``5``
            when neither ``pretrained`` nor ``pretrained_backbone`` is set.
        **kwargs: Forwarded to ``FasterRCNN``.

    Returns:
        The constructed ``FasterRCNN`` model.

    Raises:
        AssertionError: If ``trainable_backbone_layers`` is outside ``[0, 5]``.
    """
    assert 0 <= trainable_backbone_layers <= 5
    # Don't freeze any layers if neither a pretrained model nor a pretrained
    # backbone is used.
    if not (pretrained or pretrained_backbone):
        trainable_backbone_layers = 5
    if pretrained:
        # No need to download the backbone weights if the full detection
        # checkpoint is going to be loaded.
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone,
                                   trainable_layers=trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        checkpoint = load_state_dict_from_url(
            model_urls['fasterrcnn_resnet50_fpn_coco'], progress=progress)
        model_dict = model.state_dict()
        # Keep only entries that exist in the model with an identical shape.
        # Filtering on the key alone lets shape-mismatched tensors through,
        # and load_state_dict then fails with a size-mismatch error.
        compatible = {
            key: value
            for key, value in checkpoint.items()
            if key in model_dict and value.shape == model_dict[key].shape
        }
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    return model
# Build an untrained torchvision Faster R-CNN, passing custom normalization
# statistics through to the model constructor.
# NOTE(review): image_mean / image_std are defined outside this snippet; for a
# 4-channel input they presumably need four entries each - confirm.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=False,
    image_mean=image_mean,
    image_std=image_std,
)
# Replace the backbone's first convolution with one that accepts 4 input
# channels; the remaining hyper-parameters are the ones given below.
model.backbone.body.conv1 = nn.Conv2d(
    in_channels=4,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
I've customized the ResNet backbone to accept 4 input channels and would like to load the pre-trained state dict into my custom model, but the `load_state_dict` method does not seem to work. Is there any advice for this? Thanks.