1-channel Faster R-CNN

I’m trying to implement Faster R-CNN with ResNet50 as the backbone. This is my model class:

import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet50
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import BackboneWithFPN


class MyFasterRCNN(nn.Module):
    def __init__(self, num_classes=5):
        super(MyFasterRCNN, self).__init__()

        backbone = resnet50(weights=None)
        # Single-channel input: replace the stock 3-channel stem convolution
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        backbone_with_fpn = BackboneWithFPN(
            backbone,
            return_layers={'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'},
            in_channels_list=[256, 512, 1024, 2048],
            out_channels=256
        )

        # Define the Region Proposal Network (RPN) anchor generator.
        # One size tuple per feature map (4 FPN levels + the extra pool level);
        # the large aspect ratio is for extremely tall objects.
        rpn_anchor_generator = AnchorGenerator(
            sizes=((16,), (64,), (128,), (256,), (512,)),
            aspect_ratios=((0.1, 0.5, 1.0, 2.0, 4.0),) * 5
        )

        # Define the RoI pooling feature extractor
        roi_pooler = torchvision.ops.MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2
        )

        # Create the Faster R-CNN model
        self.model = FasterRCNN(
            backbone_with_fpn,
            num_classes=num_classes,
            rpn_anchor_generator=rpn_anchor_generator,
            box_roi_pool=roi_pooler
        )

    def forward(self, x):
        if len(x) == 2:
            # Training: x is an (images, targets) pair of sequences
            images = list(x[0])
            targets = list(x[1])
            return self.model(images, targets)
        else:
            # Inference: x is a list of image tensors
            return self.model(x)

My images are single-channel, so I have updated the first layer of the ResNet model.
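As an aside (not in the original post): with `weights=None` nothing is lost by replacing `conv1`. If pretrained weights were used instead, a common trick is to keep the pretrained stem by summing its kernel over the RGB channels. A minimal sketch:

# Hypothetical variant, not from the original post: adapt a *pretrained*
# ResNet stem to 1-channel input by summing its RGB kernel weights.
from torchvision.models import resnet50, ResNet50_Weights

net = resnet50(weights=ResNet50_Weights.DEFAULT)
w = net.conv1.weight.data                      # shape [64, 3, 7, 7]
net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net.conv1.weight.data = w.sum(dim=1, keepdim=True)  # shape [64, 1, 7, 7]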

This debug log shows the input values for a batch with batch size 2 (the input looks to me like what the model expects).
The images are resized to 512x512 before being fed to the model:

images
[tensor([[[-2.1194, -2.1779, -2.2365, …, -1.7095, -1.3972, -1.1435], [-2.3…7532, …, -0.4799, -0.4995, -0.5190]]]), tensor([[[-0.5775, -0.5580, -0.5190, …, -0.7337, -0.7532, -0.7922], [0.5…6705, …, -2.2560, -2.2365, -2.2169]]])]
len(images)
2
images[0].shape
torch.Size([1, 512, 512])
images[1].shape
torch.Size([1, 512, 512])
targets
[{'boxes': tensor([[310., 131., 378., 133.]]), 'labels': tensor([4])}, {'boxes': tensor([[104., 67., 126., 438.], [376., 105., 387., 437.]]), 'labels': tensor([1, 1])}]

This is the error I get:

  File "C:\Users\hidri\MLA\AM04\src\models\fasterrcnn.py", line 55, in forward
    return self.model(images, targets)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\hidri\MLA\AM04\src\modules\object_detection_module.py", line 35, in forward
    return self.model(x)
           ^^^^^^^^^^^^^
  File "C:\Users\hidri\MLA\AM04\src\modules\object_detection_module.py", line 39, in training_step
    predictions = self((images, targets))
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\hidri\MLA\AM04\src\train_object_detection.py", line 83, in <module> trainer.fit(model=module, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
RuntimeError: Given groups=1, weight of size [64, 1, 7, 7], expected input[2, 3, 800, 800] to have 1 channels, but got 3 channels instead

Can someone help me to understand what I’m missing?

Hi,

When you create an instance of your model and inspect its layers, for example by using:

print(model)

You will see at the beginning:

MyFasterRCNN(
  (model): FasterRCNN(
    (transform): GeneralizedRCNNTransform(
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        Resize(min_size=(800,), max_size=1333, mode='bilinear')

This means there are built-in transforms that expect 3-channel images.
To change the transform for 1-channel images, you could do something like this:

from torchvision.models.detection.transform import GeneralizedRCNNTransform
import torch

image_mean = [0.485]
image_std = [0.229]
transforms = GeneralizedRCNNTransform(800, 1333, image_mean, image_std)
model.model.transform = transforms
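
To confirm the fix, a quick sanity check (a minimal sketch, assuming the MyFasterRCNN class from the question):

# Sanity-check sketch: after patching the transform, a dummy 1-channel
# image should pass through without the channel-mismatch error.
model = MyFasterRCNN(num_classes=5)
model.model.transform = GeneralizedRCNNTransform(800, 1333, [0.485], [0.229])
model.eval()
with torch.no_grad():
    detections = model([torch.randn(1, 512, 512)])
print(detections[0].keys())  # dict_keys(['boxes', 'labels', 'scores'])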

This definitely solved the problem. We updated the mean and the std to [0] and [1], since we already normalize the images ourselves.
Unfortunately the model is not learning.

Here is the current network implementation (we changed the backbone to ResNet18, since it was the best-performing CNN on a classification task in the same domain):

import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.detection.transform import GeneralizedRCNNTransform


class MyFasterRCNN(nn.Module):
    def __init__(self, num_classes=5):
        super(MyFasterRCNN, self).__init__()

        resnet_net = torchvision.models.resnet18(weights=None)
        resnet_net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        modules = list(resnet_net.children())[:-2]
        backbone = nn.Sequential(*modules)
        
        backbone_with_fpn = BackboneWithFPN(
            backbone, 
            return_layers={'4': '0', '5': '1', '6': '2', '7': '3'},
            in_channels_list=[64, 128, 256, 512],
            out_channels=512
        )

        # Define the Region Proposal Network (RPN) anchor generator
        rpn_anchor_generator = AnchorGenerator(
            sizes=((8,), (16,), (16,), (32,), (64,)),  # one size tuple per feature map (4 FPN levels + pool)
            aspect_ratios=((0.03, 1.0, 5.0, 10.0),) * 5  # aspect ratios (h/w) for each feature map
        )

        # Define the ROI Pooling feature extractor
        roi_pooler = torchvision.ops.MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2
        )

        # Create the Faster R-CNN model
        self.model = FasterRCNN(
            backbone_with_fpn,
            num_classes=num_classes,
            rpn_anchor_generator=rpn_anchor_generator,
            box_roi_pool=roi_pooler,
            box_detections_per_img=50,
        )
        
        # Identity normalization: the images are already normalized upstream
        mean = [0.0]
        std = [1.0]
        self.model.transform = GeneralizedRCNNTransform(800, 1333, mean, std)

    def forward(self, images, targets=None):
        if targets is not None:
            return self.model(images, targets)
        else:
            return self.model(images)
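
Since the targets contain extremely tall, thin boxes (e.g. the 22x371 box in the debug log above), it may help to check what the configured anchors actually look like. A minimal sketch, assuming the sizes and aspect ratios from the code above; the width/height formulas are the ones torchvision's AnchorGenerator uses internally:

import math

# Sanity sketch (not from the original post): torchvision's AnchorGenerator
# builds anchors with w = size / sqrt(ratio) and h = size * sqrt(ratio),
# where ratio = h / w. Print the anchor shapes implied by the config above.
sizes = ((8,), (16,), (16,), (32,), (64,))
aspect_ratios = ((0.03, 1.0, 5.0, 10.0),) * 5

for level, (level_sizes, level_ratios) in enumerate(zip(sizes, aspect_ratios)):
    for s in level_sizes:
        for r in level_ratios:
            w = s / math.sqrt(r)
            h = s * math.sqrt(r)
            print(f"level {level}: {w:6.1f} x {h:6.1f}")

The tallest anchor here is 64 * sqrt(10) ≈ 202 px, noticeably shorter than the 371 px ground-truth box above, which may be worth keeping in mind when tuning the sizes.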

As for the evaluation metric, I compute mAP as follows:

def on_test_epoch_end(self):
    # Assumes module-level imports such as: numpy as np, pandas as pd,
    # torchvision, torchvision.ops.box_iou, and sklearn.metrics.auc
    threshold = 0.02
    num_classes = 5
    ap_per_class = {k: 0 for k in range(num_classes)}
    occurrences_per_class = {k: 0 for k in range(num_classes)}
    predictions_per_class = {k: pd.DataFrame(columns=['iou', 'correct', 'precision', 'recall']) for k in range(num_classes)}

    # Check predictions
    for _, prediction, target in self.test_outputs:
        pred_boxes = prediction['boxes'].cpu().detach().numpy()
        pred_labels = prediction['labels'].cpu().detach().numpy()
        pred_scores = prediction['scores'].cpu().detach().numpy()

        target_boxes = target['boxes'].cpu().detach().numpy()
        target_labels = target['labels'].cpu().detach().numpy()

        # For each prediction, find its best-matching target and check the label
        ious = box_iou(torch.Tensor(pred_boxes), torch.Tensor(target_boxes))
        target_idx = np.argmax(ious, axis=1)
        corrects = pred_labels == target_labels[target_idx]

        iou_values = ious[np.arange(ious.shape[0]), target_idx]
        ious = iou_values.numpy().flatten()

        # NMS to eliminate overlapping predictions
        keep = torchvision.ops.nms(torch.Tensor(pred_boxes), torch.Tensor(ious), threshold)

        for p in keep:
            c = pred_labels[p]
            if ious[p] > threshold:
                new_row = {'iou': ious[p], 'correct': corrects[p], 'precision': 0, 'recall': 0}
                predictions_per_class[c].loc[len(predictions_per_class[c])] = new_row
                predictions_per_class[c] = predictions_per_class[c].reset_index(drop=True)

        unique_labels, counts = np.unique(target_labels, return_counts=True)
        for label, count in zip(unique_labels, counts):
            occurrences_per_class[label] += count

    for c in range(num_classes):
        if len(predictions_per_class[c]) == 0:
            ap_per_class[c] = 0
            continue
        predictions_per_class[c] = predictions_per_class[c].sort_values(by='iou', ascending=False)
        TP = 0
        FP = 0

        for i, row in predictions_per_class[c].iterrows():
            TP += row['correct']
            FP += not row['correct']
            # Write back with .at: the `row` yielded by iterrows() is a copy,
            # so assigning to it would leave the DataFrame columns at 0
            predictions_per_class[c].at[i, 'recall'] = TP / occurrences_per_class[c]
            predictions_per_class[c].at[i, 'precision'] = TP / (TP + FP)

        precision = predictions_per_class[c]['precision'].values
        recall = predictions_per_class[c]['recall'].values

        # Average Precision as the area under the precision-recall curve
        ap_per_class[c] = auc(recall, precision)

    mAP = np.array(list(ap_per_class.values())).mean()
    self.log(f'mAP@{threshold}', mAP, logger=True, prog_bar=True, on_step=False, on_epoch=True)
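
As an independent cross-check of the hand-rolled computation above, torchmetrics ships a COCO-style mAP metric that consumes the same prediction/target dict structure. A minimal sketch, run inside on_test_epoch_end (depending on the torchmetrics version, this may require pycocotools):

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Cross-check sketch: compare the hand-rolled mAP with torchmetrics'
# COCO-style implementation on the same outputs.
metric = MeanAveragePrecision()
for _, prediction, target in self.test_outputs:
    # preds need 'boxes', 'scores', 'labels'; targets need 'boxes', 'labels'
    metric.update([prediction], [target])
self.log('mAP_torchmetrics', metric.compute()['map'])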

Input images are 512x512, single channel. The maximum number of objects per image is 6 in the train and validation sets, while in the test set it can go up to 20.
Moreover, we do not have much computing power and training this model takes quite some time, so any suggestion to speed up the debugging phase would also help.

The main problem is that I obtain a mAP of 0; the network does not seem to be learning.
Can someone suggest what we could change? Is the mAP calculation correct?

The hyperparameters are listed below (a minimal setup sketch follows the list):

  1. lr=1e-5
  2. batch_size=4
  3. optimizer=adam
  4. scheduler=plateau
  5. epochs=20 (the loss isn’t decreasing after this point)
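
For reference, a minimal sketch of that setup inside the LightningModule (the monitored metric name 'train_loss' is an assumption, not from the original post):

import torch

# Minimal sketch of the hyperparameters above; 'train_loss' is assumed.
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    return {
        'optimizer': optimizer,
        'lr_scheduler': {'scheduler': scheduler, 'monitor': 'train_loss'},
    }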