Adjusting the prediction layer of maskrcnn_resnet50_fpn for transfer learning and fine-tuning

My Approach

I want to retrain a pretrained object detection model using transfer learning, and then fine-tune it further so that it detects the objects in my custom dataset better.

For inference: I first check which pretrained model has already learned features that help it at least detect (draw a bounding box around) the objects in my custom data.
Inference code:

# import necessary libraries
%matplotlib inline
import matplotlib.pyplot as plt 
from PIL import Image
import torch
import torchvision.transforms as T
import torchvision
import numpy as np 
import cv2
import warnings
warnings.filterwarnings('ignore')

# Build Mask R-CNN (ResNet-50 backbone + FPN) with COCO-pretrained weights and
# switch it to evaluation mode; Module.eval() returns the module itself, so the
# call can be chained onto the constructor.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()


# COCO category names, indexed by the integer label ids the pretrained
# torchvision detection model returns (get_prediction indexes this list with
# pred['labels']); index 0 is the background class.
# The 'N/A' entries are kept as placeholders so that list position stays equal
# to the model's label id (e.g. index 44 = 'bottle', 77 = 'cell phone',
# 84 = 'book') -- do not remove them.
# we will use the same list for this notebook
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

def get_prediction(img_path, confidence):
  """
  get_prediction
    parameters:
      - img_path - path of the input image
      - confidence - threshold value for prediction score; only detections
        whose score is strictly greater than this are kept
    returns:
      - pred_boxes - list of [(x1, y1), (x2, y2)] corner pairs (float
        coordinates), one per kept detection
      - pred_class - list of COCO class names, aligned with pred_boxes
      Both lists are empty when no detection clears the threshold.
    method:
      - image is loaded from the path and converted to 3-channel RGB
      - the image is converted to a tensor using PyTorch's transforms
      - image is passed through the global `model` without gradient tracking
      - class names and box coordinates are kept only for detections whose
        score exceeds the threshold
  """
  # convert('RGB') so RGBA / grayscale / palette files still give the RGB
  # input the COCO-pretrained model expects; the with-block closes the file.
  with Image.open(img_path) as src:
    img = T.ToTensor()(src.convert('RGB'))
  # Inference only: no autograd graph is needed.
  with torch.no_grad():
    pred = model([img])[0]
  labels = pred['labels'].tolist()
  scores = pred['scores'].tolist()
  boxes = pred['boxes'].numpy()
  pred_boxes, pred_class = [], []
  # Filter on each score directly. The previous list.index() approach returned
  # the first of any duplicate scores (dropping valid detections), was O(n^2),
  # and raised IndexError when nothing passed the threshold.
  for label, score, box in zip(labels, scores, boxes):
    if score > confidence:
      pred_boxes.append([(box[0], box[1]), (box[2], box[3])])
      pred_class.append(COCO_INSTANCE_CATEGORY_NAMES[label])
  return pred_boxes, pred_class

def detect_object(img_path,image_save_path, confidence=0.5, rect_th=1, text_size=1, text_th=1):
  """
  detect_object
    parameters:
      - img_path - path of the input image
      - image_save_path - path the annotated image is written to
      - confidence - threshold value for prediction score
      - rect_th - thickness of bounding box
      - text_size - size of the class label text
      - text_th - thickness of the text
    method:
      - prediction is obtained from get_prediction method
      - for each prediction, bounding box is drawn and text is written
        with opencv
      - the annotated image is saved to image_save_path and displayed
    raises:
      - FileNotFoundError - if OpenCV cannot read img_path
  """
  boxes, pred_cls = get_prediction(img_path, confidence)
  img = cv2.imread(img_path)
  if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError(f"could not read image: {img_path}")
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  for (top_left, bottom_right), label in zip(boxes, pred_cls):
    # OpenCV drawing functions need integer pixel coordinates; the model
    # returns float box corners.
    pt1 = tuple(int(round(float(v))) for v in top_left)
    pt2 = tuple(int(round(float(v))) for v in bottom_right)
    cv2.rectangle(img, pt1, pt2, color=(0, 255, 0), thickness=rect_th)
    cv2.putText(img, label, pt1, cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)

  # img is RGB at this point, while cv2.imwrite expects BGR.
  cv2.imwrite(image_save_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
  plt.figure(figsize=(20,30))
  plt.imshow(img)
  plt.xticks([])
  plt.yticks([])
  plt.show()

# Annotate one image, keeping detections scoring above 0.18 (lower than the
# 0.5 default). source_file_path / save_file_path must be defined beforehand.
detect_object(img_path=source_file_path, image_save_path=save_file_path, confidence=0.18)

The learned features seem to align with the COCO classes "bottle", "cell phone" and "book".

For training:

I am following the article "Building your own object detector — PyTorch vs TensorFlow and how to even get started?". In that article, the box_predictor in roi_heads of the pretrained fasterrcnn_resnet50_fpn model is replaced by a new FastRCNNPredictor. The new predictor takes the same number of input features as the old classifier (cls_score.in_features) but outputs num_classes scores:

def get_model(num_classes, init_class_ids=None):
    """Return a COCO-pretrained Faster R-CNN whose box predictor has num_classes outputs.

    parameters:
      - num_classes - number of output classes *including* background,
        e.g. 2 for a single custom class.
      - init_class_ids - optional sequence of ``num_classes - 1`` COCO label
        ids (positions in the 91-entry COCO label list). When given, the new
        predictor is warm-started by copying the pretrained classifier and
        box-regression rows for background (id 0) followed by these ids,
        instead of starting from random initialisation. For example, ``[77]``
        ("cell phone") reuses that class's weights for one custom class.
    returns:
      - the model with ``roi_heads.box_predictor`` replaced.
    raises:
      - ValueError - if init_class_ids does not contain num_classes - 1 ids.
    """
    # Local imports keep the helper self-contained; FastRCNNPredictor was
    # previously used without being imported.
    import torch
    import torchvision
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    class_ids = None
    if init_class_ids is not None:
        # Row 0 of the pretrained head is background; custom classes follow.
        class_ids = [0, *init_class_ids]
        if len(class_ids) != num_classes:
            raise ValueError(
                f"init_class_ids must contain num_classes - 1 = {num_classes - 1} "
                f"ids, got {len(class_ids) - 1}"
            )

    # load an object detection model pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    old_predictor = model.roi_heads.box_predictor
    # get the number of input features for the classifier
    in_features = old_predictor.cls_score.in_features
    # replace the pre-trained head with a new one sized for num_classes
    new_predictor = FastRCNNPredictor(in_features, num_classes)

    if class_ids is not None:
        # bbox_pred emits 4 consecutive regression values per class, so class
        # c occupies rows 4*c .. 4*c+3.
        box_rows = [4 * c + k for c in class_ids for k in range(4)]
        with torch.no_grad():
            new_predictor.cls_score.weight.copy_(old_predictor.cls_score.weight[class_ids])
            new_predictor.cls_score.bias.copy_(old_predictor.cls_score.bias[class_ids])
            new_predictor.bbox_pred.weight.copy_(old_predictor.bbox_pred.weight[box_rows])
            new_predictor.bbox_pred.bias.copy_(old_predictor.bbox_pred.bias[box_rows])

    model.roi_heads.box_predictor = new_predictor
    return model

maskrcnn_resnet50_fpn

I listed all the parameters of maskrcnn_resnet50_fpn that have requires_grad set:

# Load the COCO-pretrained Mask R-CNN and put it in evaluation mode
# (Module.eval() returns the module itself, so the call can be chained).
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).eval()

# Print every parameter that is currently trainable, i.e. whose gradient
# would be computed during further training; skip the frozen ones.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print("Name: ", name, "\nBool: ", param.requires_grad, "\n")

OUTPUT

Name:  backbone.body.layer2.0.conv1.weight 
Bool:  True 

Name:  backbone.body.layer2.0.conv2.weight 
Bool:  True 

Name:  backbone.body.layer2.0.conv3.weight 
Bool:  True 

Name:  backbone.body.layer2.0.downsample.0.weight 
Bool:  True 

Name:  backbone.body.layer2.1.conv1.weight 
Bool:  True 

Name:  backbone.body.layer2.1.conv2.weight 
Bool:  True 

Name:  backbone.body.layer2.1.conv3.weight 
Bool:  True 

Name:  backbone.body.layer2.2.conv1.weight 
Bool:  True 

Name:  backbone.body.layer2.2.conv2.weight 
Bool:  True 

Name:  backbone.body.layer2.2.conv3.weight 
Bool:  True 

Name:  backbone.body.layer2.3.conv1.weight 
Bool:  True 

Name:  backbone.body.layer2.3.conv2.weight 
Bool:  True 

Name:  backbone.body.layer2.3.conv3.weight 
Bool:  True 

Name:  backbone.body.layer3.0.conv1.weight 
Bool:  True 

Name:  backbone.body.layer3.0.conv2.weight 
Bool:  True 

Name:  backbone.body.layer3.0.conv3.weight 
Bool:  True 

Name:  backbone.body.layer3.0.downsample.0.weight 
Bool:  True 

Name:  backbone.body.layer3.1.conv1.weight 
Bool:  True 

Name:  backbone.body.layer3.1.conv2.weight 
Bool:  True 

Name:  backbone.body.layer3.1.conv3.weight 
Bool:  True 

Name:  backbone.body.layer3.2.conv1.weight 
Bool:  True 

Name:  backbone.body.layer3.2.conv2.weight 
Bool:  True 

Name:  backbone.body.layer3.2.conv3.weight 
Bool:  True 

Name:  backbone.body.layer3.3.conv1.weight 
Bool:  True 

Name:  backbone.body.layer3.3.conv2.weight 
Bool:  True 

Name:  backbone.body.layer3.3.conv3.weight 
Bool:  True 

Name:  backbone.body.layer3.4.conv1.weight 
Bool:  True 

Name:  backbone.body.layer3.4.conv2.weight 
Bool:  True 

Name:  backbone.body.layer3.4.conv3.weight 
Bool:  True 

Name:  backbone.body.layer3.5.conv1.weight 
Bool:  True 

Name:  backbone.body.layer3.5.conv2.weight 
Bool:  True 

Name:  backbone.body.layer3.5.conv3.weight 
Bool:  True 

Name:  backbone.body.layer4.0.conv1.weight 
Bool:  True 

Name:  backbone.body.layer4.0.conv2.weight 
Bool:  True 

Name:  backbone.body.layer4.0.conv3.weight 
Bool:  True 

Name:  backbone.body.layer4.0.downsample.0.weight 
Bool:  True 

Name:  backbone.body.layer4.1.conv1.weight 
Bool:  True 

Name:  backbone.body.layer4.1.conv2.weight 
Bool:  True 

Name:  backbone.body.layer4.1.conv3.weight 
Bool:  True 

Name:  backbone.body.layer4.2.conv1.weight 
Bool:  True 

Name:  backbone.body.layer4.2.conv2.weight 
Bool:  True 

Name:  backbone.body.layer4.2.conv3.weight 
Bool:  True 

Name:  backbone.fpn.inner_blocks.0.weight 
Bool:  True 

Name:  backbone.fpn.inner_blocks.0.bias 
Bool:  True 

Name:  backbone.fpn.inner_blocks.1.weight 
Bool:  True 

Name:  backbone.fpn.inner_blocks.1.bias 
Bool:  True 

Name:  backbone.fpn.inner_blocks.2.weight 
Bool:  True 

Name:  backbone.fpn.inner_blocks.2.bias 
Bool:  True 

Name:  backbone.fpn.inner_blocks.3.weight 
Bool:  True 

Name:  backbone.fpn.inner_blocks.3.bias 
Bool:  True 

Name:  backbone.fpn.layer_blocks.0.weight 
Bool:  True 

Name:  backbone.fpn.layer_blocks.0.bias 
Bool:  True 

Name:  backbone.fpn.layer_blocks.1.weight 
Bool:  True 

Name:  backbone.fpn.layer_blocks.1.bias 
Bool:  True 

Name:  backbone.fpn.layer_blocks.2.weight 
Bool:  True 

Name:  backbone.fpn.layer_blocks.2.bias 
Bool:  True 

Name:  backbone.fpn.layer_blocks.3.weight 
Bool:  True 

Name:  backbone.fpn.layer_blocks.3.bias 
Bool:  True 

Name:  rpn.head.conv.weight 
Bool:  True 

Name:  rpn.head.conv.bias 
Bool:  True 

Name:  rpn.head.cls_logits.weight 
Bool:  True 

Name:  rpn.head.cls_logits.bias 
Bool:  True 

Name:  rpn.head.bbox_pred.weight 
Bool:  True 

Name:  rpn.head.bbox_pred.bias 
Bool:  True 

Name:  roi_heads.box_head.fc6.weight 
Bool:  True 

Name:  roi_heads.box_head.fc6.bias 
Bool:  True 

Name:  roi_heads.box_head.fc7.weight 
Bool:  True 

Name:  roi_heads.box_head.fc7.bias 
Bool:  True 

Name:  roi_heads.box_predictor.cls_score.weight 
Bool:  True 

Name:  roi_heads.box_predictor.cls_score.bias 
Bool:  True 

Name:  roi_heads.box_predictor.bbox_pred.weight 
Bool:  True 

Name:  roi_heads.box_predictor.bbox_pred.bias 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn1.weight 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn1.bias 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn2.weight 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn2.bias 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn3.weight 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn3.bias 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn4.weight 
Bool:  True 

Name:  roi_heads.mask_head.mask_fcn4.bias 
Bool:  True 

Name:  roi_heads.mask_predictor.conv5_mask.weight 
Bool:  True 

Name:  roi_heads.mask_predictor.conv5_mask.bias 
Bool:  True 

Name:  roi_heads.mask_predictor.mask_fcn_logits.weight 
Bool:  True 

Name:  roi_heads.mask_predictor.mask_fcn_logits.bias 
Bool:  True 

I froze all parameters by setting requires_grad = False, except roi_heads.box_predictor.cls_score.weight and roi_heads.box_predictor.cls_score.bias, which need tuning for the new dataset:

# Freeze every currently-trainable parameter except the box classifier's
# weight and bias, then report the resulting requires_grad flag for each.
# NOTE(review): if roi_heads.box_predictor is replaced afterwards, the new
# layers' parameters are created with requires_grad=True by default.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if name not in ("roi_heads.box_predictor.cls_score.weight",
                    "roi_heads.box_predictor.cls_score.bias"):
        param.requires_grad = False
    print("Name: ", name, "\nBool: ", param.requires_grad, "\n")

Where I am confused:

I am currently training for only one class (i.e. num_classes = 2, counting the background class).

My custom objects are mostly detected as "cell phone", and occasionally as "bottle" or "book".

Should I follow the same approach as in "Building your own object detector — PyTorch vs TensorFlow and how to even get started?"? That is, should I replace box_predictor in roi_heads with a new FastRCNNPredictor that keeps the cls_score.in_features inputs of the pretrained fasterrcnn_resnet50_fpn head but has num_classes outputs?

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# get the number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one; num_classes counts background,
# e.g. num_classes = 2 for a single custom class (must be defined beforehand)
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# NOTE(review): for Mask R-CNN the mask head (roi_heads.mask_predictor) still
# has the COCO output size; torchvision's fine-tuning tutorial also replaces it
# with a MaskRCNNPredictor sized for num_classes.

Since the "cell phone" features already work reasonably well on my custom data, is there a way to extract the "cell phone" class's weight and bias from the pretrained classifier? I would like to use them to initialize roi_heads.box_predictor.cls_score.weight and roi_heads.box_predictor.cls_score.bias of the new layer.