YOLOv5 from scratch is not converging after 50 epochs

I have implemented yolov5m from pseudo-scratch and I am having troubles to debug the loss function. In particular after 50 epochs on coco128 (first 128 images of MS COCO) my net is having a MAP of 0, objectness accuracy of 0% and no-obj accuracy near 100 %.
I checked the code before the loss fn and targets are built correctly (checked by plotting images and the boxes) and therefore I concluded that the issue is in the loss fx.
I created my loss function it by merging some parts of yolov5 ultralytics loss fn and other parts by Aladdin Persson.

The result is the following:

class YOLO_LOSS:
    def __init__(self, model):

        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()

        self.lambda_class = 1
        self.lambda_noobj = 1
        self.lambda_obj = 1
        self.lambda_box = 10

        anchors = getattr(model.head, "anchors").clone().detach()
        self.anchors = torch.cat((anchors[0], anchors[1], anchors[2]), dim=0)

        self.na = self.anchors.shape[0]
        self.num_anchors_per_scale = self.na // 3
        self.S = getattr(model.head, "stride")
        self.ignore_iou_thresh = 0.5
        self.ph = None  # this variable is used in the build_targets method, defined here for readability.
        self.pw = None  # this variable is used in the build_targets method, defined here for readability.

    def __call__(self, preds, targets, pred_size):
        # list of lists --> [pred[0].height, pred[0].width, pred[1].height... etc]

        targets = [self.build_targets(preds, bboxes, pred_size) for bboxes in targets]

        t1 = torch.stack([target[0] for target in targets], dim=0).to(config.DEVICE)
        t2 = torch.stack([target[1] for target in targets], dim=0).to(config.DEVICE)
        t3 = torch.stack([target[2] for target in targets], dim=0).to(config.DEVICE)

        anchors = self.anchors.reshape(3, 3, 2).to(config.DEVICE)

        loss = (self.compute_loss(preds[0], t1, anchors=anchors[0])
                + self.compute_loss(preds[1], t2, anchors=anchors[1])
                + self.compute_loss(preds[2], t3, anchors=anchors[2]))

        return loss

    def build_targets(self, input_tensor, bboxes, pred_size, check_loss=True):

        ph = pred_size[0]
        pw = pred_size[1]

        targets = [
            torch.zeros((self.num_anchors_per_scale, input_tensor[i].shape[2], input_tensor[i].shape[3], 6))
            for i in range(len(self.S))

        classes = [box[-1] - 1 for box in bboxes]  # classes in coco start from 1
        bboxes = [box[:-1] for box in bboxes]
        bboxes = rescale_bboxes(bboxes, starting_size=(640, 640), ending_size=(pw, ph))

        for idx, box in enumerate(bboxes):
            class_label = classes[idx]
            box = coco_to_yolo(box, pw, ph)

            iou_anchors = iou_width_height(torch.tensor(box[2:4]),
                                           self.anchors / torch.tensor([640, 640]).to(config.DEVICE))
            anchor_indices = iou_anchors.argsort(descending=True, dim=0)

            x, y, width, height, = box
            has_anchor = [False] * 3

            for anchor_idx in anchor_indices:
                # i.e if the best anchor idx is 8, num_anchors_per_scale
                # we know that 8//3 = 2 --> the best scale_idx is 2 -->
                # best_anchor belongs to last scale (52,52)
                # scale_idx will be used to slice the variable "targets"
                # another pov: scale_idx searches the best scale of anchors
                scale_idx = torch.div(anchor_idx, self.num_anchors_per_scale, rounding_mode="floor")
                # print(scale_idx)
                # anchor_on_scale searches the idx of the best anchor in a given scale
                # found via index in the line below
                anchor_on_scale = anchor_idx % self.num_anchors_per_scale
                # slice anchors based on the idx of the best scales of anchors
                scale_x = input_tensor[int(scale_idx)].shape[2]
                scale_y = input_tensor[int(scale_idx)].shape[3]
                i, j = int(scale_y * y), int(scale_x * x)  # which cell
                anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
                if not anchor_taken and not has_anchor[scale_idx]:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = 1
                    x_cell, y_cell = scale_x * x - j, scale_y * y - i  # both between [0,1]
                    width_cell, height_cell = (
                        width * scale_x,
                        height * scale_y,
                    )  # can be greater than 1 since it's relative to cell
                    box_coordinates = torch.tensor(
                        [x_cell, y_cell, width_cell, height_cell]
                    targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
                    targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
                    has_anchor[scale_idx] = True
                elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
                    targets[scale_idx][anchor_on_scale, i, j, 0] = -1  # ignore prediction

        return targets

    def compute_loss(self, preds, targets, anchors):
        # originally anchors have shape (3,2) --> 3 set of anchors of width and height
        anchors = anchors.reshape(1, 3, 1, 1, 2)

        # because of https://github.com/ultralytics/yolov5/issues/471
        xy = preds[..., 1:3].sigmoid() * 2 - 0.5
        wh = (preds[..., 3:5].sigmoid() * 2) ** 2 * anchors

        # Check where obj and noobj (we ignore if target == -1)
        obj = targets[..., 0] == 1  # in paper this is Iobj_i
        noobj = targets[..., 0] == 0  # in paper this is Inoobj_i

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #

        # not doing sigmoid because self.bce is bce_with_logits
        no_object_loss = self.bce(
            # [..., 0:1] instead of [..., 0] to keep shape untouched
            (preds[..., 0:1][noobj]), (targets[..., 0:1][noobj]),
        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #
        # https://www.youtube.com/watch?v=Grir6TZbc1M&t=5668s
        # explanation of anchors at 12:17:
        # dim=-1 means the last dim
        box_preds = torch.cat([xy, wh], dim=-1)
        ious = intersection_over_union(box_preds, targets[..., 1:5], GIoU=True).detach()
        # mse or bce??
        object_loss = self.bce(self.sigmoid(preds[..., 0:1][obj]), ious[obj] * targets[..., 0:1][obj])

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #

        # preds[..., 1:3] = self.sigmoid(preds[..., 1:3])   # x,y coordinates
        # targets[..., 3:5] = ((targets[..., 3:5]/2)**(1/2))/anchors
        # width, height coordinates
        # box_loss = self.mse(preds[..., 1:5][obj], targets[..., 1:5][obj])

        box_loss = (1 - ious[obj]).mean()

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #

        class_loss = self.entropy(
            (preds[..., 5:][obj]), (targets[..., 5][obj].long()),

        return (
            self.lambda_box * box_loss
            + self.lambda_obj * object_loss
            + self.lambda_noobj * no_object_loss
            + self.lambda_class * class_loss

I know it’s quite a vague topic but I am having quite a hard time figuring out what’s going wrong.