Hi everyone, I’m dealing with a multilabel segmentation problem where I have 2 classes: a background class and an object class. The objects are very small, so there is a class imbalance in the data. I therefore decided to use BCEWithLogitsLoss with pos_weight.
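For reference, this is roughly the kind of per-class weighting I have in mind. It is only a sketch with made-up masks, and it uses the common negatives/positives ratio rather than the exact formula in my code further down:

import torch

# Made-up binary masks shaped [Batch, Number of Classes, H, W] (illustrative only)
labels = (torch.rand(4, 2, 64, 64) > 0.95).float()

# Positive and negative pixel counts per class channel
positives = labels.sum(dim=(0, 2, 3))                   # shape [C]
negatives = labels.numel() / labels.size(1) - positives

# Common heuristic: pos_weight = negatives / positives, one value per class
pos_weight = negatives / positives.clamp(min=1)

# Reshape to [C, 1, 1] so it broadcasts over [Batch, C, H, W] logits/targets
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.view(-1, 1, 1))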
- For the first epoch the training loss was very high (around 7), and this only happens when I use the pos_weight parameter. Is it normal to start with such a high loss? (See the quick check after this list for the kind of scale difference I mean.)
- From the second epoch on the loss came back to a normal range, but after that the validation loss did not decrease. In general, what is the best practice for choosing a loss function in such tasks?
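To illustrate the scale difference from the first point, here is a quick check on random tensors (purely illustrative, not my real data). The positive-pixel terms are multiplied by pos_weight, so the reported loss starts on a larger scale even for identical predictions:

import torch

torch.manual_seed(0)
logits = torch.randn(4, 2, 64, 64)                   # untrained-model-like outputs
targets = (torch.rand(4, 2, 64, 64) > 0.95).float()  # sparse positive pixels

plain = torch.nn.BCEWithLogitsLoss()
weighted = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]))

print(plain(logits, targets).item())     # mean loss without weighting
print(weighted(logits, targets).item())  # same logits, noticeably larger reported loss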
The model outputs (and also the target images) have this shape: [Batch, Number of Classes, H, W].
def _calculate_positive_weight(self, label_images) -> torch.Tensor:
    # label_images: binary masks of shape [Batch, Number of Classes, H, W]
    # Positive pixel count per class channel
    class_counts = label_images.sum(dim=(0, 2, 3))
    # Total number of pixels per channel
    total_count = label_images.numel() / label_images.size(1)
    # Weight derived from the object channel (index 1); clamp avoids division by zero
    positive_weight = total_count / (2 * class_counts[1].float().clamp(min=1))
    # Keep it as a 1-element tensor instead of re-wrapping it with torch.tensor()
    return positive_weight.reshape(1).to(self.device)
def _initialize_criterion(self):
    # Estimate pos_weight from the first training batch only, then stop iterating
    for batch in self.training_data_loader:
        label_images = batch["label_image_tensor"].float().to(self.device)
        positive_weight_tensor = self._calculate_positive_weight(label_images)
        self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight_tensor)
        break
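For completeness, the criterion is then used in the training step roughly like this (self.model, self.optimizer and the "image_tensor" key are placeholder names here, not exact names from my code):

def _training_step(self, batch):
    # Placeholder names: self.model, self.optimizer and "image_tensor" are assumptions
    images = batch["image_tensor"].to(self.device)
    targets = batch["label_image_tensor"].float().to(self.device)

    logits = self.model(images)             # [Batch, Number of Classes, H, W]
    loss = self.criterion(logits, targets)  # sigmoid is applied inside BCEWithLogitsLoss

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()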
I appreciate your support.