Loss function for bounding boxes in rpn model

Hi, I am learning pytorch and ML learning. Now, I am working on rpn model that should paint bounding boxes around a specific objects in images. I have only 2 classes: target objects and background.
I created a custom rpn structure and checked, everything work correctly. I have a backbone that provides feature maps and two heads, classification and regression. I work with the images of constant shape, thus, the initial anchors were generated only once. The output of classification head is [bs, X] and the output of regression head is [bs, X, 4] (offset), where X is total number of unchors.
Using offset and anchosr i build the bboxes

    def postprocess_(self, offset):
        if offset.device != self.anchors.device:
            self.anchors = self.anchors.to(offset.device)
        bboxes = []
        for ind, preds in enumerate(offset):
            x_center = self.anchors[0][..., 0] + preds[..., 0]* self.anchors[0][..., 0]   # New X centers
            y_center = self.anchors[0][..., 1] + preds[..., 1]* self.anchors[0][..., 1]  # New Y centers
            width = self.anchors[0][..., 2]*preds[..., 2] # New width
            height = self.anchors[0][..., 3]*preds[..., 3] # New height
            bbox = torch.stack([x_center, y_center, width, height], dim=-1)
        return torch.stack(bboxes, dim=0).to(offset.device, non_blocking = True)

But, I have an issue with the loss function. The model doesnt want to tune. I see the dectreaseof the regression error, while the classification error is almost unchangable.The code of the loss class is below (this is draft)

class RPNLoss(torch.nn.Module):
    def __init__(self, pos_iou_thresh=0.3, neg_iou_thresh=0.1, lambda_reg=1.0, cls_loss = None):

        self.pos_iou_thresh = pos_iou_thresh
        self.neg_iou_thresh = neg_iou_thresh
        self.lambda_reg = lambda_reg
        self.cls_loss = F.binary_cross_entropy_with_logits if cls_loss is None else cls_loss

    def forward(self, anchors, preds, gt_bboxes, offset):
        probs = preds['probs']  
        bboxes = preds['bboxes']  
        device = anchors.device

        cls_pos_targets = torch.zeros_like(probs, device=device)  
        cls_neg_targets = torch.zeros_like(probs, device=device)  
        reg_targets = torch.zeros_like(bboxes, device=device)   
        reg_mask = torch.zeros_like(probs, device=device)  

        iou_loss = 0
        reg_loss = 0
        cls_loss = 0
        for b in range(bs):
            gt = gt_bboxes[b]  
            gt = gt[~torch.all(gt == 0, dim=-1)]  
            if len(gt) == 0:
                continue  # if we do not have groundtruth
            ious = self.compute_iou(bboxes[b], gt)  # calc IoU between bboxes and groundtruth

            max_ious, gt_indices = ious.max(dim=1) 
            # Positive boxes
            pos_mask = max_ious >= self.pos_trh
            iou_loss += (1 -max_ious[pos_mask].sum()/pos_mask.sum())
            neg_mask = max_ious < self.neg_trh
            #update pos and neg targets
            cls_pos_targets[b, pos_mask] = 1
            cls_neg_targets[b, neg_mask] = 1

            # regression targets (use only positive anchors, boxes and groundtruth)
            pos_anchors = anchors[0][pos_mask]
            pos_gt = gt[gt_indices[pos_mask]]
            reg_targets[b, pos_mask] = self.box_to_deltas(pos_anchors, pos_gt) # calculate the required offset between the anchors and their groundtruth
            pred_delta = self.box_to_deltas(pos_anchors,bboxes[b, pos_mask]) # predicte offset between the anchors and bboxes
            reg_loss += F.smooth_l1_loss(pred_delta, reg_targets[b, pos_mask], reduction="sum")/max(1, pos_mask.sum()) # calculate regression loss

        # classification losses. Most probably the error is here
        total = cls_pos_targets.sum() + cls_neg_targets.sum()
        pos_count = cls_pos_targets.sum().clamp(min=1)
        neg_count = cls_neg_targets.sum().clamp(min=1)
        total_count = pos_count + neg_count
        cls_pos_loss = focal(probs, cls_pos_targets).sum()
        cls_neg_loss = focal(-probs, cls_neg_targets).sum()/neg_count
        cls_loss = cls_pos_loss +  cls_neg_loss
        # Total loss
        total_loss = cls_loss/total + self.lambda_reg * reg_loss 
        # print(f'classification loss {cls_loss/total} and regression loss {reg_loss}, iou loss = {iou_loss}')
        return total_loss
    def compute_iou(boxes1, boxes2):
        Calc iou
        boxes1: [N, 4]
        boxes2: [M, 4]
        boxes1 = convert_to_corners(boxes1)
        boxes2 = convert_to_corners(boxes2)
        inter_x1 = torch.max(boxes1[:, None, 0], boxes2[:, 0])
        inter_y1 = torch.max(boxes1[:, None, 1], boxes2[:, 1])
        inter_x2 = torch.min(boxes1[:, None, 2], boxes2[:, 2])
        inter_y2 = torch.min(boxes1[:, None, 3], boxes2[:, 3])

        inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
        union_area = area1[:, None] + area2 - inter_area
        # union_area = torch.clamp(union_area, min=1e-6) 
        return inter_area/(union_area + 1e-3)

    def box_to_deltas(anchors, gt_boxes):
        calc offset (tx, ty, tw, th).
        anchors: [N, 4]
        gt_boxes: [N, 4]
        dx = ( gt_boxes[:, 0] -anchors[:, 0]) / (anchors[:, 0]+1e-6)
        dy = (gt_boxes[:, 1]-anchors[:, 1 ]) / (anchors[:, 1]+1e-6)
        dw = gt_boxes[:, 2] /(anchors[:, 2]+1e-6)
        dh = gt_boxes[:, 3] /(anchors[:, 3]+1e-6)
        return torch.stack([dx, dy, dw, dh], dim=-1)

def convert_to_corners( boxes):
        boxex convertation from [xc, yc, w, h] in [x1, y1, x2, y2].
        boxes: [N, 4] — [xc, yc, w, h].
        xc, yc, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
        x1 = xc - w / 2
        y1 = yc - h / 2
        x2 = xc + w / 2
        y2 = yc + h / 2
        return torch.stack([x1, y1, x2, y2], dim=-1)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=self.reduction)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduction == 'mean':
            return F_loss.mean()
        elif self.reduction == 'sum':
            return F_loss.sum()
            return F_loss

What do I do wrong? I am ready to decide that ML is not mine…
I understand thet it should minimize difference between the target and predicted values, but WTF?

All work. Just was need to find appropriate contribution coeficient for the regression lost and for the focal losses

Hmmm. I was too optimistic. It doesnt work properly. Everuthing is ok with regression lose but i cant handle the classification one