Overlapping Issues

I am working on a synthetic 3D medical image dataset with two main classes. Each volume has a spatial size of 512 × 512 and a varying depth, i.e. y × 512 × 512. My approach is to train a 2D U-Net on the slices. I generate the background channel from the one-hot encoded targets of the two classes: a background pixel is 1 where both classes are 0, and 0 where either class is 1. This background channel is then stacked with the class channels.
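A minimal sketch of the background construction described above, assuming the class masks of one slice are already a (C, H, W) tensor of 0/1 values. The function name `add_background_channel` and the toy 4 × 4 masks are hypothetical, for illustration only. Because the two classes can overlap, the resulting target is multi-hot rather than strictly one-hot: an overlapping pixel is 1 in both class channels and 0 in the background.

```python
import torch

def add_background_channel(class_masks: torch.Tensor) -> torch.Tensor:
    """Prepend a background channel to per-class binary masks.

    class_masks: (C, H, W) tensor of 0/1 values, one channel per foreground class.
    Returns a (C + 1, H, W) tensor whose channel 0 is 1 only where every
    foreground channel is 0.
    """
    foreground_any = class_masks.amax(dim=0, keepdim=True)  # 1 where any class is present
    background = 1 - foreground_any                          # 1 where no class is present
    return torch.cat([background, class_masks], dim=0)

# Two overlapping toy masks (hypothetical data); they share pixel (1, 1)
mask_a = torch.zeros(4, 4)
mask_a[0:2, 0:2] = 1
mask_b = torch.zeros(4, 4)
mask_b[1:3, 1:3] = 1
target = add_background_channel(torch.stack([mask_a, mask_b]))
print(target.shape)  # torch.Size([3, 4, 4])
```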

The data is highly imbalanced, and the two main classes overlap each other; only the background, given the way it is generated, never overlaps with them. After several training runs with unreasonable results (the model was not learning), I went through similar and related discussions on the forum, which suggested that I am dealing with a multi-class, multi-label segmentation problem. I therefore switched to BCEWithLogitsLoss, passing the raw logits and the stacked one-hot encoded target, both with shape (nBatch, nClass, Height, Width), with a batch size of 8. The loss now decreases, i.e. the model is training.
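For reference, a minimal sketch of the loss setup described above, assuming three channels (background plus two classes) and a single-channel input. The `nn.Conv2d` is only a stand-in for the 2D U-Net, and the random tensors are placeholders, not real data. The points it illustrates: the target is a float tensor with the same shape as the logits, and no sigmoid or softmax is applied before the loss, because BCEWithLogitsLoss applies a sigmoid to each channel independently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_batch, n_class, height, width = 8, 3, 64, 64  # 3 = background + 2 classes; small size for the demo

# Stand-in for the 2D U-Net: any module mapping (N, 1, H, W) -> (N, C, H, W) logits
model = nn.Conv2d(1, n_class, kernel_size=3, padding=1)

images = torch.randn(n_batch, 1, height, width)
# Multi-hot target (placeholder data): float, same shape as the logits,
# and several channels may be 1 at the same pixel
targets = (torch.rand(n_batch, n_class, height, width) > 0.7).float()

criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy per channel
logits = model(images)              # raw scores, no activation applied here
loss = criterion(logits, targets)
loss.backward()
```

Because each channel is an independent binary problem, a pixel that belongs to both classes is simply a positive in both channels, which a softmax-based CrossEntropyLoss cannot express.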

I would appreciate clarification and explanation on the following:

  • Choice of weighted loss function: which is appropriate here, the pos_weight or the weight argument of BCEWithLogitsLoss? Can other weighted loss functions be considered?
  • Resizing the spatial dimensions of the image and target, since I am short on computational resources.
  • Handling the evaluation metrics; I am currently using DiceMetric from MONAI.
  • Any other helpful insights and suggestions.
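A hedged sketch touching the three questions above, using plain PyTorch under stated assumptions; the helper names `compute_pos_weight`, `downsample_pair` and `per_channel_dice` are hypothetical. In BCEWithLogitsLoss, `weight` rescales the loss of every element (positives and negatives alike), while `pos_weight` scales only the positive term, which is the usual lever for rare foreground pixels. For a (N, C, H, W) target, `pos_weight` needs shape (C, 1, 1) so it broadcasts per channel; a plain (C,) tensor would broadcast against the width axis instead. A common choice, assumed here, is negatives / positives per channel. For resizing, the image is downsampled bilinearly and the target with nearest-neighbour interpolation so it stays binary (thin structures may still disappear). For evaluation, predictions are thresholded per channel rather than argmax-ed, since argmax would force a single label per pixel and erase the overlap.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def compute_pos_weight(targets: torch.Tensor, min_pos: float = 1.0) -> torch.Tensor:
    """Per-channel negatives/positives ratio, shaped (C, 1, 1) so it
    broadcasts over (N, C, H, W) targets in BCEWithLogitsLoss."""
    positives = targets.sum(dim=(0, 2, 3))
    negatives = targets[:, 0].numel() - positives      # N * H * W pixels per channel
    return (negatives / positives.clamp(min=min_pos)).view(-1, 1, 1)

def downsample_pair(image: torch.Tensor, target: torch.Tensor, size):
    """Resize image bilinearly and target with nearest-neighbour (keeps 0/1 labels)."""
    image_small = F.interpolate(image, size=size, mode="bilinear", align_corners=False)
    target_small = F.interpolate(target, size=size, mode="nearest")
    return image_small, target_small

def per_channel_dice(logits: torch.Tensor, targets: torch.Tensor,
                     threshold: float = 0.5, eps: float = 1e-6) -> torch.Tensor:
    """Dice per channel on thresholded sigmoid outputs (no argmax, so overlaps survive)."""
    preds = (torch.sigmoid(logits) > threshold).float()
    dims = (0, 2, 3)
    intersection = (preds * targets).sum(dims)
    denominator = preds.sum(dims) + targets.sum(dims)
    return (2 * intersection + eps) / (denominator + eps)

# Usage with placeholder tensors (not real data)
torch.manual_seed(0)
images = torch.randn(2, 1, 512, 512)
targets = (torch.rand(2, 3, 512, 512) > 0.9).float()
images_s, targets_s = downsample_pair(images, targets, (256, 256))
pos_weight = compute_pos_weight(targets_s)             # ideally computed over the whole training set
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn_like(targets_s)                  # stand-in for model output
loss = criterion(logits, targets_s)
dice = per_channel_dice(logits, targets_s)            # one score per channel
```

Reporting Dice per channel, rather than a single averaged number, makes it easier to see whether the rare classes are being learned at all. A soft Dice term added to the BCE loss is another commonly used complement for imbalance, though whether it helps here would need testing.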