Hello,
I am trying to build a Mask R-CNN model with a ResNet-101 backbone, but training fails, apparently because of the anchor_generator I pass in.
Here is how I defined my model:
import torch
import torchvision
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RPNHead
bmodel = torchvision.models.resnet101(weights='ResNet101_Weights.IMAGENET1K_V2')
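# Drop the average pooling and fully connected layers, keeping only the convolutional trunk
backbone = torch.nn.Sequential(*(list(bmodel.children())[:-2]))
# Replace the first convolution so the backbone accepts 1-channel input
backbone[0] = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=(3, 3), bias=False)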
anchor_sizes = ((6,), (9,), (16,), (32,), (64,))
aspect_ratios = ((1.0, 1.25, 1.5),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                                output_size=7,
                                                sampling_ratio=2)
mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'],
                                                     output_size=14,
                                                     sampling_ratio=2)
model = MaskRCNN(backbone,
                 num_classes=2,
                 rpn_anchor_generator=anchor_generator,
                 box_roi_pool=roi_pooler,
                 mask_roi_pool=mask_roi_pooler,
                 min_size=(256,), max_size=256,
                 box_detections_per_img=1024,
                 box_batch_size_per_image=1024,
                 trainable_backbone_layers=2)
model.rpn.head = RPNHead(256, anchor_generator.num_anchors_per_location()[0], conv_depth=2)
grcnn = torchvision.models.detection.transform.GeneralizedRCNNTransform(min_size=256, max_size=256, image_mean=[0.485], image_std=[0.229])
model.transform = grcnn
model.rpn.head.conv[0][0] = torch.nn.Conv2d(2048, 256, kernel_size=7, stride=2, padding=(3, 3), bias=False)
device = 'cuda'
model.to(device)
The model builds fine, but it throws an error as soon as I start training:
AssertionError: Anchors should be Tuple[Tuple[int]] because each feature map could potentially have different sizes and aspect ratios. There needs to be a match between the number of feature maps passed and the number of sizes / aspect ratios specified.
Could anyone explain how I can find out the number of feature maps the backbone passes on, and how I should define the anchor generator to match? Thank you in advance.
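For reference, here is a minimal, torch-free sketch of the consistency check that the assertion message describes. It is not torchvision code: count_feature_maps and check_anchor_config are hypothetical helpers, and the placeholder objects stand in for real tensors. The sketch assumes that a backbone returning a single tensor counts as one feature map (torchvision wraps such an output in a one-entry dict), while a dict-like output, such as the one an FPN-style backbone produces, contributes one feature map per key.

```python
from collections import OrderedDict
from collections.abc import Mapping


def count_feature_maps(backbone_output):
    """Count the feature maps a detection model would see (hypothetical helper).

    Assumption: a dict-like output holds one feature map per key, while any
    other object (for example a single tensor) counts as exactly one map.
    """
    if isinstance(backbone_output, Mapping):
        return len(backbone_output)
    return 1


def check_anchor_config(num_feature_maps, anchor_sizes, aspect_ratios):
    """Return True when there is one sizes tuple and one ratios tuple per map."""
    return len(anchor_sizes) == len(aspect_ratios) == num_feature_maps


# Stand-in for the output of a plain nn.Sequential backbone: a single tensor.
single_map_output = object()
# Stand-in for an FPN-style output: several named feature maps.
fpn_output = OrderedDict((name, object()) for name in ["0", "1", "2", "3", "pool"])

# The configuration from the post: five size tuples, five ratio tuples.
anchor_sizes = ((6,), (9,), (16,), (32,), (64,))
aspect_ratios = ((1.0, 1.25, 1.5),) * len(anchor_sizes)

# Five tuples against one feature map: this is the mismatch the assertion reports.
print(check_anchor_config(count_feature_maps(single_map_output), anchor_sizes, aspect_ratios))
# Five tuples against five feature maps: consistent.
print(check_anchor_config(count_feature_maps(fpn_output), anchor_sizes, aspect_ratios))

# One possible single-map configuration: all sizes inside a single inner tuple.
single_sizes = ((6, 9, 16, 32, 64),)
single_ratios = ((1.0, 1.25, 1.5),)
print(check_anchor_config(count_feature_maps(single_map_output), single_sizes, single_ratios))
```

With a real model, the same count could be obtained by passing a dummy batch through the backbone and checking whether the result is a single tensor or a dict. If the backbone yields only one map, the featmap_names given to MultiScaleRoIAlign would presumably also need to shrink to ['0'].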