Mask RCNN fine-tuning problem

I want to fine-tune a Mask RCNN model with ResNet50 backbone, but the model isn’t converging at all.
Here’s a sample code and results from 1st and 5th epoch:

num_classes = 1
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained = True)

for param in model.parameters():
    param.requires_grad = False

in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr = 0.001, momentum = 0.9, weight_decay = 0.0005)

class MyDataset(Dataset):
    def __init__(self, df, DATA_PATH, transform = None):
        self.transform = transform
        self.df = df
        self.DATA_PATH = DATA_PATH
        self.length = len(os.listdir(self.DATA_PATH))
        self.items = list(self.df['id'].unique())

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        item = self.items[idx]
        image = Image.open(os.path.join(DATA_PATH, 'train', item + '.png'))
        image = torchvision.transforms.ToTensor()(image)
        bbox, mask = self.get_bbox_mask(item)
        bbox = torch.Tensor(bbox).reshape(-1, 4)
        bbox[:, 2] = bbox[:, 0] + bbox[:, 2]
        bbox[:, 3] = bbox[:, 1] + bbox[:, 3]
        mask = torch.Tensor(mask).reshape(-1, 704, 520)
        return image, bbox, mask
        
    def get_bbox_mask(self, item):
        return create_bbox_list_by_image(self.df, item), create_mask(self.df, item, 704, 520)

DATA_PATH = "path/to/data"


train = MyDataset(pd.read_csv(os.path.join(DATA_PATH, 'train.csv')), os.path.join(DATA_PATH, 'train'))

batch_size = 1
num_epochs = 200

trainloader = torch.utils.data.DataLoader(train, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

for epoch in tqdm(range(num_epochs)):
    running_loss = []
    loss_classifier = []
    loss_box_reg = []
    loss_mask = []
    loss_objectness = []
    loss_rpn_box_reg = []
    for batch in tqdm(trainloader):
        image, bbox_target, mask_target = batch
        image = image.to(device)
        
        bbox_target = bbox_target.reshape(-1, 4)
        target = {}
        target["boxes"] = bbox_target
        target["masks"] = mask_target
        target["labels"] = torch.ones(bbox_target.shape[0], dtype=torch.int64)
        target = [target]
        for key, value in target[0].items():
            target[0][key] = target[0][key].to(device)
        optimizer.zero_grad()
        loss_dict = model(image, target)
        loss_classifier.append(loss_dict['loss_classifier'])
        loss_box_reg.append(loss_dict['loss_box_reg'])
        loss_mask.append(loss_dict['loss_mask'])
        loss_objectness.append(loss_dict['loss_objectness'])
        loss_rpn_box_reg.append(loss_dict['loss_rpn_box_reg'])
        losses = sum(loss for loss in loss_dict.values())
        running_loss.append(losses)
        losses.backward()
        optimizer.step()
    print(f"Epoch {epoch + 1}, loss: {sum(running_loss)/len(running_loss)}")
    print(f"Loss classifier: {sum(loss_classifier)/len(loss_classifier)}")
    print(f"Loss box regression: {sum(loss_box_reg)/len(loss_box_reg)}")
    print(f"Loss mask: {sum(loss_mask)/len(loss_mask)}")
    print(f"Loss objectness: {sum(loss_objectness)/len(loss_objectness)}")
    print(f"Loss rpn box regression: {sum(loss_rpn_box_reg)/len(loss_rpn_box_reg)}")

First thing that comes to mind is that I’m using batch size of 1 (both for hardware limitation and because of the fact that not every image has the same amount of bounding boxes, so I would need to pad bounding boxes to the same shape, but ResNet50 (used as a backbone in Mask RCNN) uses Batch Normalization, that doesn’t make much sense with batch size of 1. As far as I know, there are no pretrained backbones that use Group Normalization instead of batch normalization, so my options are kinda limited here.

Of course, if there’s anything wrong with the code here, please tell me, I wonder why my model doesn’t want to converge