FasterRCNN Mismatch: RuntimeError: Given groups=1, weight of size [256, 256, 3, 3], expected input[1, 512, 64, 64] to have 256 channels, but got 512 channels instead

Hi!

I’m working on a customized Faster RCNN to be used with 4 channel 512x512 images.

I’m running into a size mismatch when I try to run my data through, and I’m not sure why.

Here’s my code:

class CustomBackbone(torch.nn.Module):
    def __init__(self, backbone, out_channels):
        super().__init__()
        self.body = backbone
        self.out_channels = out_channels

    def forward(self, x):
        return self.body(x)

class ModifiedResNet50(resnet.ResNet):
'''
Modify ResNet to accept more channels and 512x512 images
'''
    def __init__(self, num_channels=4, pretrained=False):
        super().__init__(resnet.Bottleneck, [3, 4, 6, 3])
        
        self.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        if not pretrained:
            self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

class ModifiedFeaturePyramidNetwork(feature_pyramid_network.FeaturePyramidNetwork):
    def __init__(self, in_channels_list, out_channels, extra_blocks=None):
        super().__init__(in_channels_list, out_channels, extra_blocks)
        
        # Modify the lateral layers to match the backbone output channels
        self.lateral_layers = torch.nn.ModuleList([
            torch.nn.Conv2d(in_channels, out_channels, 1) for in_channels in in_channels_list
        ])

resnet50 = ModifiedResNet50(num_channels=4)
backbone_body = torchvision.models._utils.IntermediateLayerGetter(resnet50, return_layers=return_layers)
backbone = CustomBackbone(backbone_body, out_channels)

fpn = ModifiedFeaturePyramidNetwork(
    in_channels_list=in_channels_list,
    out_channels=out_channels,
    extra_blocks=None,
)

model = FasterRCNN(backbone, num_classes=n_classes, rpn_anchor_generator=AnchorGenerator())

trfrm = torchvision.models.detection.transform.GeneralizedRCNNTransform(min_size=300, max_size=1000, 
                                                                        image_mean=[0,0,0,0], image_std=[1,1,1,1],
                                                                        num_channels=n_channels)
model.transform = trfrm

Thanks in advance for any help with this!