Hi!
I’m working on a customized Faster RCNN to be used with 4 channel 512x512 images.
I’m running into a size mismatch when I try to run my data through, and I’m not sure why.
Here’s my code:
class CustomBackbone(torch.nn.Module):
def __init__(self, backbone, out_channels):
super().__init__()
self.body = backbone
self.out_channels = out_channels
def forward(self, x):
return self.body(x)
class ModifiedResNet50(resnet.ResNet):
'''
Modify ResNet to accept more channels and 512x512 images
'''
def __init__(self, num_channels=4, pretrained=False):
super().__init__(resnet.Bottleneck, [3, 4, 6, 3])
self.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
if not pretrained:
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, torch.nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, torch.nn.BatchNorm2d):
torch.nn.init.constant_(m.weight, 1)
torch.nn.init.constant_(m.bias, 0)
class ModifiedFeaturePyramidNetwork(feature_pyramid_network.FeaturePyramidNetwork):
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
super().__init__(in_channels_list, out_channels, extra_blocks)
# Modify the lateral layers to match the backbone output channels
self.lateral_layers = torch.nn.ModuleList([
torch.nn.Conv2d(in_channels, out_channels, 1) for in_channels in in_channels_list
])
resnet50 = ModifiedResNet50(num_channels=4)
backbone_body = torchvision.models._utils.IntermediateLayerGetter(resnet50, return_layers=return_layers)
backbone = CustomBackbone(backbone_body, out_channels)
fpn = ModifiedFeaturePyramidNetwork(
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=None,
)
model = FasterRCNN(backbone, num_classes=n_classes, rpn_anchor_generator=AnchorGenerator())
trfrm = torchvision.models.detection.transform.GeneralizedRCNNTransform(min_size=300, max_size=1000,
image_mean=[0,0,0,0], image_std=[1,1,1,1],
num_channels=n_channels)
model.transform = trfrm
Thanks in advance for any help with this!