Adapting fasterrcnn_mobilenet_v3_large_fpn for binary-class detection (background vs. defect)

Hi,
Please suggest how to change the last layers of fasterrcnn_mobilenet_v3_large_fpn so that it predicts only two classes, e.g. 0 for background and 1 for defect.

Here is the sample code.

import numpy as np
from torchmetrics.detection import IntersectionOverUnion
from torchmetrics.detection import MeanAveragePrecision
import math

class CocoDNN(L.LightningModule):
    """Faster R-CNN (MobileNetV3-Large FPN) detector wrapped as a LightningModule.

    The pretrained torchvision detector is loaded with ``weights="DEFAULT"`` and
    its final box predictor (the ``cls_score`` / ``bbox_pred`` layers held in
    ``roi_heads.box_predictor``) is replaced by a freshly initialised
    ``FastRCNNPredictor`` with ``num_classes`` outputs. The backbone, FPN and
    RPN keep their pretrained weights.

    ``num_classes`` INCLUDES the background class. For binary defect detection
    keep the default of 2: label 0 is reserved for background, and every
    annotated box in the targets must carry label 1 ("defect"). Never put
    label 0 on a real box.

    Args:
        num_classes: Number of classes predicted by the new head, background
            included.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Record num_classes so load_from_checkpoint rebuilds a head of the
        # same size as the one that was trained.
        self.save_hyperparameters()

        self.model = models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights="DEFAULT")

        # Swap only the last classification/box-regression layers. The
        # pretrained head is sized for the COCO label set; the new one is
        # sized for our own labels (background + defect by default).
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(
            in_features, num_classes
        )

        self.metric = MeanAveragePrecision(
            iou_type="bbox",
            average="macro",
            class_metrics=True,
            iou_thresholds=[0.5, 0.75],
            extended_summary=True,
            backend="faster_coco_eval",
        )

    def forward(self, images, targets=None):
        """Delegate to the wrapped detector.

        Args:
            images: Image tensors, forwarded unchanged to ``self.model``.
            targets: Optional target dicts, forwarded unchanged to ``self.model``.
        """
        return self.model(images, targets)

    def training_step(self, batch, batch_idx):
        """Return the mean of the summed detection losses over non-empty sub-batches.

        ``batch`` is unpacked as ``(imgs, annot)``, two parallel sequences of
        sub-batches (presumably lists of image tensors and matching target
        dicts — TODO confirm against the collate_fn). Empty sub-batches are
        skipped.

        Returns:
            The scalar mean loss, or ``None`` when every sub-batch was empty so
            Lightning skips this step instead of failing on ``torch.stack([])``.
            NOTE(review): skipping steps this way can desynchronise DDP ranks.
        """
        imgs, annot = batch
        batch_losses = []
        for img_b, annot_b in zip(imgs, annot):
            if len(img_b) == 0:
                continue
            # In training mode the model returns a dict of named loss tensors.
            loss_dict = self.model(img_b, annot_b)
            batch_losses.append(sum(loss_dict.values()))

        if not batch_losses:
            return None

        batch_mean = torch.stack(batch_losses).mean()
        self.log('train_loss', batch_mean, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return batch_mean

    def validation_step(self, batch, batch_idx):
        """Run inference on each non-empty sub-batch and accumulate mAP state.

        Nothing is computed or logged here. Calling ``compute()`` per step is
        expensive and, without a ``reset()``, mixes data across epochs.
        Aggregation happens once per epoch in ``on_validation_epoch_end``.
        """
        imgs, annot = batch
        targets, preds = [], []
        for img_b, annot_b in zip(imgs, annot):
            if len(img_b) == 0:
                continue
            # Lightning runs validation with the model in eval mode, where the
            # detector returns one prediction dict per image; targets are not
            # needed for inference.
            preds.extend(self.model(img_b))
            targets.extend(annot_b)

        if preds:
            self.metric.update(preds, targets)

    def on_validation_epoch_end(self):
        """Compute mAP@0.5, mAP@0.75 and mean recall for the epoch, log them, then reset."""
        map_results = self.metric.compute()

        # With extended_summary=True, ``recall`` is a multi-dimensional tensor
        # where -1 marks (IoU, class, area, max-dets) cells that have no ground
        # truth. Average only the valid cells so the sentinel does not bias the
        # logged value; log -1 if nothing is valid.
        recall = map_results['recall']
        valid_recall = recall[recall > -1]
        if valid_recall.numel():
            mean_recall = valid_recall.mean()
        else:
            mean_recall = recall.new_tensor(-1.0)

        self.log('recall', mean_recall.float().item(), on_epoch=True, prog_bar=True, logger=True)
        self.log('map_50', map_results['map_50'].float().item(), on_epoch=True, prog_bar=True, logger=True)
        self.log('map_75', map_results['map_75'].float().item(), on_epoch=True, prog_bar=True, logger=True)

        # Clear accumulated predictions/targets so the next epoch starts fresh.
        self.metric.reset()

    def configure_optimizers(self):
        """Use SGD with momentum and weight decay over all module parameters."""
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)