I am fine-tuning a Mask R-CNN model (the torchvision implementation, i.e. the one described in the "Mask R-CNN — Torchvision main documentation" page) to perform instance segmentation of wires (powerline wires, for example). However, even though both the training loss and the validation loss go down during training, when I evaluate the model I find that it predicts no masks at all, regardless of the input image.
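For context, the validation loss above is computed roughly like this (a sketch of the standard torchvision pattern, since the exact training/validation loop is omitted here; torchvision detection models only return a loss dict when they are in train mode and given targets):

# Sketch only: a train-mode forward pass with targets yields the loss dict.
model.train()
with torch.no_grad():
    loss_dict = model(imgs, targets)  # e.g. loss_classifier, loss_mask, loss_objectness, ...
    val_loss = sum(loss for loss in loss_dict.values())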
The evaluation code looks something like this. The print statements at the bottom always print 0.0, and when I inspect the raw model output directly it contains no predicted boxes or masks either:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def eval(
    val_dataset,
    model,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    batch_size: int = 8,
) -> None:
    val_dl = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=coco_collate_fn,  # custom collate fn for detection targets, defined elsewhere
        num_workers=4,
    )
    boxes_per_img = 0.0
    masks_per_img = 0.0
    NUM_IMGS = len(val_dataset)
    with torch.no_grad():
        model.eval()
        for i, (imgs, targets) in enumerate(tqdm(val_dl, total=len(val_dl), position=1)):
            imgs, targets = map_to_device(imgs, targets, device)  # move batch to GPU, defined elsewhere
            preds = model(imgs, targets)  # in eval mode this returns one prediction dict per image
            boxes_per_img += sum(len(p["boxes"]) for p in preds) / NUM_IMGS
            masks_per_img += sum(len(p["masks"]) for p in preds) / NUM_IMGS
        model.train()  # restore training mode
    print(f"(Validation) Boxes per image: {boxes_per_img}")
    print(f"(Validation) Masks per image: {masks_per_img}")
Does anyone know why this could happen? I would expect segmentation/detection errors, but not zero predictions. I suspected class imbalance (wires are thin objects, although there is at least one in every image) and changed some model parameters to oversample positive regions, but that did not change the result. The model is set up as follows:
# Load a pre-trained model
model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
    weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1,
    weights_backbone=torchvision.models.ResNet50_Weights.IMAGENET1K_V2,
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, 1) # only one class: wire
# Increase the positive sample fraction to oversample wires
model.roi_heads.box_sampler_batch_size_per_image = 768 # Default is 512
model.roi_heads.box_positive_fraction = 0.8 # Default is 0.25
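For reference, in the torchvision implementation the box sampling settings are held by roi_heads.fg_bg_sampler (a BalancedPositiveNegativeSampler), so another way to change them would be the sketch below (based on the current torchvision attribute layout; I am not certain the attribute names I set above have the same effect):

# Alternative, assuming current torchvision internals: the sampler object itself
# exposes batch_size_per_image and positive_fraction.
model.roi_heads.fg_bg_sampler.batch_size_per_image = 768
model.roi_heads.fg_bg_sampler.positive_fraction = 0.8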