How do I train a system on highly imbalanced data? Why is my system not learning?

I don't know whether my data counts as unbalanced, as I would imagine most MRI datasets are unbalanced.

I have 3D volumes of size [30, 512, 1024], and I am targeting small spheres in each volume. I have a target tensor for each input. Since each volume contains only a small number of events (spheres), the majority of the target tensor is black (0), with the spheres of interest labeled as 1.

I assume this is how all binary segmentation volumes are labelled: the background is 0 and the areas of interest are 1.

To iterate over each volume I have divided it into patches of size [30, 128, 128], so some patches contain nothing, and for those the target patch is all 0. I noticed that for such patches I am consistently getting a very high loss and very low accuracy.
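For reference, the patching step described above can be sketched as follows. This is only an illustration, not the actual dataloader: it assumes non-overlapping patches, uses NumPy for brevity, and `split_into_patches` plus the small cube standing in for a sphere are hypothetical. It shows how few patches end up containing any foreground.

```python
import numpy as np

def split_into_patches(volume, patch_hw=(128, 128)):
    """Split a [D, H, W] volume into non-overlapping [D, ph, pw] patches."""
    d, h, w = volume.shape
    ph, pw = patch_hw
    if h % ph or w % pw:
        raise ValueError("H and W must be divisible by the patch size")
    return (
        volume.reshape(d, h // ph, ph, w // pw, pw)
        .transpose(1, 3, 0, 2, 4)   # -> [rows, cols, D, ph, pw]
        .reshape(-1, d, ph, pw)     # -> [num_patches, D, ph, pw]
    )

# Hypothetical target volume: one small cube stands in for a labelled sphere.
target = np.zeros((30, 512, 1024), dtype=np.uint8)
target[10:14, 200:204, 300:304] = 1

patches = split_into_patches(target)
has_foreground = patches.reshape(len(patches), -1).any(axis=1)

print(patches.shape)               # (32, 30, 128, 128)
print(int(has_foreground.sum()))   # 1 patch out of 32 contains foreground
```

With one small object per volume, 31 of the 32 patches are pure background, so most training iterations see an all-zero target.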

For patches that do contain data, i.e. where the target patch shows the location of my spheres, I do get a meaningful loss and accuracy, but the accuracy is still very low.

This is the kind of output I am getting from my system. It looks like it is not learning at all, with so many patches containing nothing of interest to me.

I wrote the input and target images to disk for each iteration, and the iterations with Loss: 1 are the patches where there is nothing in the target patch.

Am I doing something wrong? Is my system learning?

idx:  0 of  312 - Training Loss:  1.0 - Training Accuracy:  3.204042239857152e-11
idx:  5 of  312 - Training Loss:  0.9876335859298706 - Training Accuracy:  0.0119545953348279
idx:  10 of  312 - Training Loss:  1.0 - Training Accuracy:  7.269467666715101e-11
idx:  15 of  312 - Training Loss:  0.7320756912231445 - Training Accuracy:  0.22638492286205292
idx:  20 of  312 - Training Loss:  0.3599294424057007 - Training Accuracy:  0.49074622988700867
idx:  25 of  312 - Training Loss:  1.0 - Training Accuracy:  1.0720428988975073e-09
idx:  30 of  312 - Training Loss:  1.0 - Training Accuracy:  1.19782361807097e-09
idx:  35 of  312 - Training Loss:  1.0 - Training Accuracy:  1.956790285362331e-09
idx:  40 of  312 - Training Loss:  1.0 - Training Accuracy:  1.6055999862985004e-09
idx:  45 of  312 - Training Loss:  1.0 - Training Accuracy:  7.580232552761856e-10
idx:  50 of  312 - Training Loss:  1.0 - Training Accuracy:  9.510597864803572e-10
idx:  55 of  312 - Training Loss:  1.0 - Training Accuracy:  1.341515676323013e-09
idx:  60 of  312 - Training Loss:  0.7165247797966003 - Training Accuracy:  0.02658153884112835
idx:  65 of  312 - Training Loss:  1.0 - Training Accuracy:  4.528208030762926e-09
idx:  70 of  312 - Training Loss:  0.3205708861351013 - Training Accuracy:  0.6673439145088196
idx:  75 of  312 - Training Loss:  0.9305377006530762 - Training Accuracy:  2.3437689378624782e-05
idx:  80 of  312 - Training Loss:  1.0 - Training Accuracy:  5.305786885401176e-07
idx:  85 of  312 - Training Loss:  1.0 - Training Accuracy:  4.0612556517771736e-07
idx:  90 of  312 - Training Loss:  0.8207412362098694 - Training Accuracy:  0.0344742126762867
idx:  95 of  312 - Training Loss:  0.7463213205337524 - Training Accuracy:  0.19459737837314606
idx:  100 of  312 - Training Loss:  1.0 - Training Accuracy:  4.863646818620282e-09
idx:  105 of  312 - Training Loss:  0.35790306329727173 - Training Accuracy:  0.608722984790802
idx:  110 of  312 - Training Loss:  1.0 - Training Accuracy:  3.3852198821904267e-09
idx:  115 of  312 - Training Loss:  1.0 - Training Accuracy:  1.5268487585373691e-09
idx:  120 of  312 - Training Loss:  1.0 - Training Accuracy:  3.46353523639209e-09
idx:  125 of  312 - Training Loss:  1.0 - Training Accuracy:  2.5878148582347826e-11
idx:  130 of  312 - Training Loss:  1.0 - Training Accuracy:  2.3601216467272756e-11
idx:  135 of  312 - Training Loss:  1.0 - Training Accuracy:  1.1504343033763575e-09
idx:  140 of  312 - Training Loss:  0.4516671299934387 - Training Accuracy:  0.13879922032356262

I read that Dice loss works much better for such datasets, and that is what I am using here. I am using a 3D implementation of U-Net.

I’m very new to this and want to ask whether I am on the right path.

  • Should I use some sort of weighted loss to factor in the imbalance in my dataset?

  • Should I increase the patch size so that there is a higher probability of points of interest being included in each patch?

  • Is this running as it is supposed to? And should I just increase the number of epochs to a very high number and just let it train overnight and hope that it learns?

I assume MRI scans where only small tumors are detected produce similar datasets. Is my understanding correct?

Many thanks

You could try a weighted loss using the inverse of the class frequencies, which might help if your model is overfitting to the majority class.
However, I’m wondering why you’re getting a high loss for patches containing just background.
If your model overfits to the background class, these patches should be nearly perfectly classified.
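As a rough sketch of the inverse-frequency idea, the weights could be computed from the label voxels like this. The helper `inverse_frequency_weights`, the `eps` value, and the toy target are assumptions for illustration; in practice the counts would probably be taken over the whole training set rather than a single patch.

```python
import numpy as np

def inverse_frequency_weights(target, num_classes=2, eps=1e-8):
    """Return per-class weights proportional to 1 / class frequency."""
    counts = np.bincount(target.ravel(), minlength=num_classes).astype(np.float64)
    freqs = counts / counts.sum()
    weights = 1.0 / (freqs + eps)      # eps guards against an absent class
    return weights / weights.sum()     # normalise so the weights sum to 1

# Hypothetical patch: 64 foreground voxels out of 30 * 128 * 128.
target = np.zeros((30, 128, 128), dtype=np.int64)
target[10:14, 60:64, 60:64] = 1

w = inverse_frequency_weights(target)
print(w)  # the foreground weight (index 1) dominates the background weight
```

Weights like these could then be passed to a weighted criterion, for example the `weight` argument of `torch.nn.CrossEntropyLoss`.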

Thanks for answering. This is what I don't understand either, which is why I am asking here.

I am using the model provided here: GitHub - wolny/pytorch-3dunet: 3D U-Net model for volumetric semantic segmentation written in pytorch. I am only using their implementation of the U-Net model along with their implementations of DiceLoss and DiceCoefficient. I have my own dataloader class for my dataset and my own script to train the model.

The system works as expected with the provided random data, but with my data I am getting the above results. I manually checked each input and label as well, and they seem to be in order. The background-only patches should not give such a high loss.

I don't know what to check next at this stage, or what to do to fix this. =/

Yeah, your loss probably explodes, since it seems the class_weights will be really high when you are dealing with a background-only image (target_sum should be zero, so you are dividing by epsilon).
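To illustrate the effect, here is a minimal sketch. It assumes the class weight is the reciprocal of the squared target sum, clamped at a small epsilon, as in generalized Dice formulations. The function name `gdl_class_weight` and `eps=1e-5` are assumptions and are not taken from the repository.

```python
import numpy as np

def gdl_class_weight(target, eps=1e-5):
    """Weight = 1 / max(target_sum^2, eps), assumed generalized-Dice style."""
    target_sum = float(target.sum())
    return 1.0 / max(target_sum ** 2, eps)

background_only = np.zeros((30, 128, 128))
with_sphere = background_only.copy()
with_sphere[10:14, 60:64, 60:64] = 1.0   # 64 foreground voxels

w_sphere = gdl_class_weight(with_sphere)      # 1 / 64**2
w_bg = gdl_class_weight(background_only)      # 1 / eps

print(w_sphere, w_bg, w_bg / w_sphere)
```

Under these assumptions the background-only patch gets a weight several orders of magnitude larger than a patch that contains a sphere, which would be enough to make the loss jump.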

Here is a simple implementation from @IssamLaradji which uses a smoothing factor, which might counter this effect.

Unfortunately the linked code and the smoothing factor do not help much either: for a patch where all I have is background, tflat.sum() would be 0. This makes intersection 0 as well, so for the majority of my patches or blocks I get a return of 1.

So now rather than getting a very high loss I am getting a consistent loss of 1 for almost all patches.
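The arithmetic can be checked with a small sketch of a smoothed Dice loss. It assumes the common form 1 - (2 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) with smooth = 1, which may differ from the linked code; `smoothed_dice_loss` and the constant-valued predictions are hypothetical.

```python
import numpy as np

def smoothed_dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss with a smoothing term (assumed form, see lead-in)."""
    iflat = pred.ravel()
    tflat = target.ravel()
    intersection = (iflat * tflat).sum()
    return 1.0 - (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)

empty_target = np.zeros((30, 128, 128))

# Prediction of 0.5 everywhere: the summed probability swamps the smoothing term.
loss_uncertain = smoothed_dice_loss(np.full(empty_target.shape, 0.5), empty_target)

# Prediction very close to 0 everywhere: the loss drops well below 1.
loss_confident = smoothed_dice_loss(np.full(empty_target.shape, 1e-6), empty_target)

# Prediction exactly 0 everywhere: the smoothing term gives a loss of 0.
loss_zero = smoothed_dice_loss(np.zeros(empty_target.shape), empty_target)

print(loss_uncertain, loss_confident, loss_zero)
```

With this form, an empty target does not force the loss to 1: the loss reduces to 1 - smooth / (iflat.sum() + smooth), so it stays near 1 only while the summed predicted foreground probability is large compared with the smoothing factor. A constant loss of 1 on background patches would then suggest the network is still predicting a lot of foreground on them.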

It is strange that other people have solved similar situations with Dice loss, but it is not working for me.


I would be glad if you could share how you worked through this, as I am currently facing the same challenge. Thanks