Unbalanced dataset for a YOLOv3 implementation

Hello Guys,

I am trying to implement the YOLOv3 architecture in PyTorch. You can find a quick recap of YOLO here: Implementation of YOLOv3 in Pytorch

I took this implementation more or less as-is and, after pretraining the model, fine-tuned it on a custom dataset. I have now reached the point where my precision is biased toward the class with the most labels in the training set (class0: 2315 labels, class1: 867 labels, class2: 335 labels). Since detection is a multi-label problem (one image can contain objects of several classes), I cannot simply use WeightedRandomSampler. So I intended to implement a weighted loss that reduces the influence of class0, because it dominates the training signal.
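One common way to turn label counts like these into per-class weights is inverse-frequency weighting. The sketch below is my own assumption, not something taken from the implementation linked above: each class c gets weight total / (num_classes * count_c), so a perfectly balanced dataset would give every class a weight of 1.0 and the rarest class gets the largest weight. The function name `inverse_frequency_weights` is hypothetical.

```python
def inverse_frequency_weights(counts):
    """Return one weight per class, inversely proportional to its label count.

    weight_c = total / (num_classes * count_c), so every class contributes
    the same total weight (weight_c * count_c == total / num_classes).
    """
    total = sum(counts)
    num_classes = len(counts)
    return [total / (num_classes * c) for c in counts]


# Label counts from the training set described above (class0, class1, class2)
label_counts = [2315, 867, 335]
weights = inverse_frequency_weights(label_counts)
for cls, w in enumerate(weights):
    print(f"class{cls}: {w:.3f}")
```

The resulting list could be wrapped with `torch.tensor(weights)` and passed to a loss as per-class weights. If the weight for the rarest class turns out too aggressive, a softer variant (for example the square root of these values) is another option; which one works better would have to be checked on a validation set.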

The problem is that the implementation uses BCEWithLogitsLoss() as the loss function, and as far as I understand, its weight argument only weights the samples within each batch. What I need is a loss with which I can penalize the classes that have more samples.
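For reference, `nn.BCEWithLogitsLoss` also has a `pos_weight` argument that is broadcast along the last (class) dimension of the target, so it can hold one value per class. `nn.CrossEntropyLoss`, which the `YOLOLoss` class below creates as `self.entropy`, accepts a `weight` tensor with one entry per class. The following is a minimal sketch of both options. It assumes the class term is computed on the class logits (`preds[..., 5:]`) of the cells that contain an object. The weight values and the fake logits are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_classes = 3
# Illustrative per-class weights (roughly inverse label frequency for
# class0/class1/class2 with 2315, 867 and 335 labels): rarer classes weigh more.
class_weights = torch.tensor([0.506, 1.352, 3.500])

# Option 1: cross-entropy with one weight per class (target = class index).
# With reduction="mean" this is a weighted mean over the samples.
weighted_ce = nn.CrossEntropyLoss(weight=class_weights)

# Option 2: one-vs-rest BCE; pos_weight has one entry per class and only
# rescales the positive term of each class.
weighted_bce = nn.BCEWithLogitsLoss(pos_weight=class_weights)

# Fake class logits for 8 cells that contain an object, plus their class ids.
logits = torch.randn(8, num_classes)
labels = torch.randint(0, num_classes, (8,))

ce_loss = weighted_ce(logits, labels)

one_hot = F.one_hot(labels, num_classes).float()
bce_loss = weighted_bce(logits, one_hot)

print(ce_loss.item(), bce_loss.item())
```

The cross-entropy variant fits the single-label-per-box target format `(x_offset, y_offset, w_cell, h_cell, conf, class)` most directly. With `pos_weight`, only the positive terms are reweighted, which is not exactly the same as down-weighting class0 everywhere.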

For clarification, here is a code snippet that shows what I want to do.

import torch
import torch.nn as nn

class YOLOLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()
        # Constants signifying how much to pay for each respective part of the loss
        self.lambda_class = 1
        self.lambda_noobj = 4
        self.lambda_obj = 2
        self.lambda_box = 8

    def forward(self, preds, target, anchors):
        """Compute the YOLO loss.

        Arguments:
            preds (tensor): tensor of shape (N, 3, S, S, 5+C)
            target (tensor): tensor of shape (N, 3, S, S, 6)
            anchors (tensor): tensor of shape (3, 2)

        Prediction format:
            (x_raw, y_raw, w_raw, h_raw, conf, [classes...])

        Target format:
            (x_offset, y_offset, w_cell, h_cell, conf, class)
        """
        device = preds.device
        # Grid size S, used as a normalization factor
        scale = preds.size(2)
        # Cells whose target conf is -1 are ignored (neither obj nor noobj)
        obj = target[..., 4] == 1      # (N, 3, S, S)
        noobj = target[..., 4] == 0    # (N, 3, S, S)
        
        # NO OBJECT LOSS
        # ==========================================
        noobj_loss = self.bce(
            preds[..., 4:5][noobj],
            target[..., 4:5][noobj],
        )