How to train a system on highly imbalanced data? Why is my system not learning?

I don't know whether my data is imbalanced, although I would imagine most MRI datasets are.

I have 3D volumes of size [30, 512, 1024] and I am targeting small spheres in these volumes. I have a target tensor for each input. Since each volume contains only a small number of events (spheres), most of the target tensor is black (0), with the spheres of interest labeled as 1.

I assume this is how all binary segmentation volumes are labelled: the background is 0 and the regions we are interested in are 1.
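One way to answer "is my data imbalanced?" is to compare the number of foreground voxels with the total number of voxels. A rough sketch, assuming (hypothetically) 10 spheres of radius 5 voxels per [30, 512, 1024] volume; replace these numbers with counts from your own masks, or call the hypothetical helper `mask_fraction` on a real flattened mask.

```python
import math

def sphere_voxels(radius):
    # Approximate number of voxels inside a sphere of the given radius (in voxels).
    return 4.0 / 3.0 * math.pi * radius ** 3

def foreground_fraction(volume_shape, n_spheres, radius):
    # Fraction of voxels labeled 1, ignoring overlaps and clipping at the volume borders.
    total = math.prod(volume_shape)
    return n_spheres * sphere_voxels(radius) / total

def mask_fraction(mask):
    # Exact foreground fraction of a binary mask given as a flat sequence of 0/1 values.
    return sum(mask) / len(mask)

frac = foreground_fraction((30, 512, 1024), n_spheres=10, radius=5)
print(f"foreground fraction: {frac:.2e}, background:foreground ~ {1 / frac:.0f}:1")
```

With these assumed numbers the foreground is about 3.3e-4 of the volume, roughly 3000 background voxels for every foreground voxel, which would count as heavily imbalanced.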

To iterate over each volume, I have divided it into patches of size [30, 128, 128], so some patches contain nothing at all; for such a patch, the target patch is all 0. I noticed that for these patches I consistently get a very high loss and very low accuracy.
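With non-overlapping tiling, a [30, 512, 1024] volume splits into 1 x 4 x 8 = 32 patches of [30, 128, 128], and a handful of small spheres can only occupy a few of them. A sketch that enumerates the patch corners and marks which ones contain a sphere center; the sphere centers below are hypothetical, and a sphere that straddles a patch border is only counted in the patch holding its center.

```python
from itertools import product

def patch_origins(volume_shape, patch_shape):
    # Corner (z, y, x) of every non-overlapping patch that fits completely in the volume.
    ranges = [range(0, v - p + 1, p) for v, p in zip(volume_shape, patch_shape)]
    return list(product(*ranges))

def patches_with_foreground(patch_shape, centers):
    # Set of patch corners whose extent contains at least one sphere center.
    return {tuple((c // p) * p for c, p in zip(center, patch_shape)) for center in centers}

volume_shape = (30, 512, 1024)
patch_shape = (30, 128, 128)
centers = [(15, 40, 40), (15, 300, 900), (10, 310, 950)]  # hypothetical sphere centers

origins = patch_origins(volume_shape, patch_shape)
hits = patches_with_foreground(patch_shape, centers)
print(f"{len(origins)} patches, {len(origins) - len(hits)} with an all-zero target")
```

Here 30 of the 32 patches have an all-zero target, which is consistent with most iterations in the log below reporting a loss of 1.0.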

For patches that do contain data, i.e. where the target patch shows the location of my spheres, I get some loss and accuracy values, but the accuracy is still very low.

This is the kind of output I am getting. With so many patches containing nothing of interest, it looks like the model is not learning at all.

I wrote the input and target images to disk at each iteration, and the iterations with Loss: 1.0 are the patches whose target patch is empty.
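A loss of exactly 1.0 on empty patches follows from the Dice formula itself. A minimal sketch of a per-patch soft Dice loss, assuming the common form 1 - 2 * sum(p * t) / (sum(p) + sum(t) + eps) with eps only in the denominator (not necessarily the exact implementation used here): when the target is all zeros, the intersection sum(p * t) is zero, so the loss is 1 regardless of the prediction, even a perfect all-zero one.

```python
def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss for one patch: pred holds probabilities, target holds 0/1 labels."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    # eps only guards the denominator, so an empty target always gives 2 * 0 / (...) = 0 Dice.
    return 1.0 - 2.0 * intersection / (denominator + eps)

print(dice_loss([0.01] * 1000, [0] * 1000))  # background patch, nearly perfect prediction -> 1.0
print(dice_loss([0.0] * 1000, [0] * 1000))   # background patch, perfect prediction -> 1.0
print(dice_loss([1.0, 1.0, 0.0, 0.0], [1, 1, 0, 0]))  # sphere patch, perfect prediction: close to 0
```

In other words, a per-patch Dice loss has no way to reward correctly predicting "no sphere here".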

Am I doing something wrong? Is my system learning?

idx:  0 of  312 - Training Loss:  1.0 - Training Accuracy:  3.204042239857152e-11
idx:  5 of  312 - Training Loss:  0.9876335859298706 - Training Accuracy:  0.0119545953348279
idx:  10 of  312 - Training Loss:  1.0 - Training Accuracy:  7.269467666715101e-11
idx:  15 of  312 - Training Loss:  0.7320756912231445 - Training Accuracy:  0.22638492286205292
idx:  20 of  312 - Training Loss:  0.3599294424057007 - Training Accuracy:  0.49074622988700867
idx:  25 of  312 - Training Loss:  1.0 - Training Accuracy:  1.0720428988975073e-09
idx:  30 of  312 - Training Loss:  1.0 - Training Accuracy:  1.19782361807097e-09
idx:  35 of  312 - Training Loss:  1.0 - Training Accuracy:  1.956790285362331e-09
idx:  40 of  312 - Training Loss:  1.0 - Training Accuracy:  1.6055999862985004e-09
idx:  45 of  312 - Training Loss:  1.0 - Training Accuracy:  7.580232552761856e-10
idx:  50 of  312 - Training Loss:  1.0 - Training Accuracy:  9.510597864803572e-10
idx:  55 of  312 - Training Loss:  1.0 - Training Accuracy:  1.341515676323013e-09
idx:  60 of  312 - Training Loss:  0.7165247797966003 - Training Accuracy:  0.02658153884112835
idx:  65 of  312 - Training Loss:  1.0 - Training Accuracy:  4.528208030762926e-09
idx:  70 of  312 - Training Loss:  0.3205708861351013 - Training Accuracy:  0.6673439145088196
idx:  75 of  312 - Training Loss:  0.9305377006530762 - Training Accuracy:  2.3437689378624782e-05
idx:  80 of  312 - Training Loss:  1.0 - Training Accuracy:  5.305786885401176e-07
idx:  85 of  312 - Training Loss:  1.0 - Training Accuracy:  4.0612556517771736e-07
idx:  90 of  312 - Training Loss:  0.8207412362098694 - Training Accuracy:  0.0344742126762867
idx:  95 of  312 - Training Loss:  0.7463213205337524 - Training Accuracy:  0.19459737837314606
idx:  100 of  312 - Training Loss:  1.0 - Training Accuracy:  4.863646818620282e-09
idx:  105 of  312 - Training Loss:  0.35790306329727173 - Training Accuracy:  0.608722984790802
idx:  110 of  312 - Training Loss:  1.0 - Training Accuracy:  3.3852198821904267e-09
idx:  115 of  312 - Training Loss:  1.0 - Training Accuracy:  1.5268487585373691e-09
idx:  120 of  312 - Training Loss:  1.0 - Training Accuracy:  3.46353523639209e-09
idx:  125 of  312 - Training Loss:  1.0 - Training Accuracy:  2.5878148582347826e-11
idx:  130 of  312 - Training Loss:  1.0 - Training Accuracy:  2.3601216467272756e-11
idx:  135 of  312 - Training Loss:  1.0 - Training Accuracy:  1.1504343033763575e-09
idx:  140 of  312 - Training Loss:  0.4516671299934387 - Training Accuracy:  0.13879922032356262

I read that DiceLoss works much better for such datasets, and that is what I am using here. I am using a 3D implementation of U-Net.

I’m very new to this and want to ask if I am on the right path?

  • Should I use some sort of weighted loss to factor in the imbalance in my dataset?

  • Should I increase the patch size so that there is a higher probability of points of interest being included in each patch?

  • Is this running as it is supposed to? Or should I just set the number of epochs very high, let it train overnight and hope that it learns?
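On the patch-size question: a larger patch raises the chance of containing a sphere, but it also costs memory. A commonly used alternative is to keep the patch size and oversample patches that contain foreground. A minimal sketch, assuming you already have separate lists of patch corners with and without spheres; `sample_patch_origins`, `fg_prob` and the corner lists are hypothetical.

```python
import random

def sample_patch_origins(fg_origins, bg_origins, n, fg_prob=0.5, rng=None):
    """Draw n patch corners; with probability fg_prob a draw comes from the foreground list."""
    if not fg_origins and not bg_origins:
        raise ValueError("need at least one candidate patch")
    rng = rng or random.Random()
    samples = []
    for _ in range(n):
        # Fall back to whichever list is non-empty if one of them is empty.
        use_fg = bool(fg_origins) and (not bg_origins or rng.random() < fg_prob)
        samples.append(rng.choice(fg_origins if use_fg else bg_origins))
    return samples

# Hypothetical corners: 2 patches with spheres, the remaining 30 without.
fg = [(0, 0, 0), (0, 256, 896)]
bg = [(0, y, x) for y in range(0, 512, 128) for x in range(0, 1024, 128) if (0, y, x) not in fg]
batch = sample_patch_origins(fg, bg, n=1000, fg_prob=0.5, rng=random.Random(0))
print(sum(o in fg for o in batch) / len(batch))  # share of sampled patches containing a sphere
```

Inside a PyTorch DataLoader, `torch.utils.data.WeightedRandomSampler` can implement the same idea by giving foreground patches a larger sampling weight.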

I assume MRI datasets where only small tumors are to be detected look similar. Is my understanding correct?

Many thanks

You could try a weighted loss using the inverse of the class frequencies, which might help if your model is overfitting to the majority (background) class.
However, I’m wondering why you’re getting a high loss for patches containing just background.
If your model overfits to the background class, these patches should be nearly perfectly classified.
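A minimal sketch of inverse class-frequency weighting applied to binary cross-entropy. It assumes the weights are computed over the whole training set (or a large sample of it), not per patch, since a background-only patch contains no foreground voxels to count. The function names are hypothetical; in PyTorch the same ratio w_fg / w_bg = N_bg / N_fg is what the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss` expects.

```python
import math

def inverse_frequency_weights(labels):
    """Return (w_background, w_foreground) with w_c = N / N_c for binary 0/1 labels."""
    n = len(labels)
    n_fg = sum(labels)
    n_bg = n - n_fg
    # A class that never occurs gets weight 0 instead of a division by zero.
    w_bg = n / n_bg if n_bg else 0.0
    w_fg = n / n_fg if n_fg else 0.0
    return w_bg, w_fg

def weighted_bce(pred, target, w_bg, w_fg, eps=1e-7):
    """Mean binary cross-entropy where each voxel is scaled by the weight of its true class."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        total -= w_fg * math.log(p) if t else w_bg * math.log(1.0 - p)
    return total / len(pred)

labels = [0] * 99 + [1]                # hypothetical 1% foreground
w_bg, w_fg = inverse_frequency_weights(labels)
print(w_bg, w_fg)
print(weighted_bce([0.5] * 100, labels, w_bg, w_fg))
```

With these weights both classes contribute equally to the total loss, so the single foreground voxel is no longer drowned out by the 99 background voxels.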

Thanks for answering. This is what I don't understand either, hence my question here.

I am using the model from https://github.com/wolny/pytorch-3dunet. I only use their U-Net implementation together with their implementations of DiceLoss and the Dice coefficient; I have my own dataloader class for my dataset and my own trainer.py file to train the model.

The model works as expected with the provided random data, but with my data I get the results above. I also checked every input and label manually, and they seem to be correct. Background-only patches should not give such a high loss.

I don't know what to check next at this stage, or how to fix this. =/

Yeah, your loss probably explodes, since the class_weights seem to become really high when you are dealing with a background-only patch (target_sum should be zero, so you are dividing by epsilon).
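The effect can be reproduced with plain arithmetic. A minimal illustration, assuming a weight of the form 1 / max(target_sum, eps); the exact weighting formula in the library may differ, but any weight that divides by a clamped target sum becomes huge on an all-zero patch. `inverse_sum_weight` and the example patches are hypothetical.

```python
def inverse_sum_weight(target, eps=1e-6):
    # Illustrative class weight: 1 / sum(target), with the sum clamped from below by eps.
    return 1.0 / max(sum(target), eps)

empty_patch = [0] * (30 * 128 * 128)   # background-only [30, 128, 128] target, flattened
sphere_patch = [0] * 1000 + [1] * 500  # hypothetical target with 500 foreground voxels

print(inverse_sum_weight(sphere_patch))  # 1/500
print(inverse_sum_weight(empty_patch))   # about 1/eps, i.e. roughly 1e6
```

Any nonzero prediction on the empty patch is then multiplied by a weight around a million times larger than on a typical sphere patch.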

Here is a simple implementation from @IssamLaradji that uses a smoothing factor, which might counter this effect.


Hi,
Unfortunately, the linked code and its smoothing factor do not help much either. For a patch containing only background, tflat.sum() is 0, which makes the intersection 0 as well, so for the majority of my patches the loss returned is 1.

So now, rather than getting a very high loss, I am getting a constant loss of 1 for almost all patches.
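With an empty target, a smoothed Dice loss reduces to sum(p) / (sum(p) + smooth), where sum(p) is the summed predicted probability over the patch. A [30, 128, 128] patch has 491,520 voxels, so even an average probability of 0.01 gives sum(p) of about 4915 and a loss of about 0.9998. A minimal sketch, assuming the common form with `smooth` added to both numerator and denominator (not necessarily identical to the linked code):

```python
def smoothed_dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss with an additive smoothing term in numerator and denominator."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * intersection + smooth) / (sum(pred) + sum(target) + smooth)

n = 30 * 128 * 128          # voxels in one [30, 128, 128] patch
empty_target = [0] * n

# Small but nonzero probabilities everywhere: the loss stays close to 1.
print(smoothed_dice_loss([0.01] * n, empty_target))
# Only when the summed prediction is far below `smooth` does the loss approach 0.
print(smoothed_dice_loss([1e-7] * n, empty_target))
```

So the smoothing factor only helps once the network already predicts almost exactly zero everywhere on background patches, which is why it makes little difference early in training.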

It is strange that people have solved similar problems with Dice loss, but it's not working for me.
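One common way around the per-patch problem is to compute Dice over the whole batch rather than averaging per-patch values, so an empty patch only adds its predicted mass to the denominator instead of contributing a loss of 1 on its own. A minimal sketch with hypothetical 4-voxel patches; whether this alone fixes this particular setup is untested, and combining Dice with a (weighted) BCE term is another frequently used option.

```python
def batch_dice_loss(preds, targets, eps=1e-6):
    """Soft Dice loss with sums taken over every voxel of every patch in the batch."""
    intersection = 0.0
    denominator = 0.0
    for pred, target in zip(preds, targets):
        intersection += sum(p * t for p, t in zip(pred, target))
        denominator += sum(pred) + sum(target)
    return 1.0 - 2.0 * intersection / (denominator + eps)

# Hypothetical batch: one patch with a well-predicted sphere, one background-only patch.
preds = [[0.9, 0.9, 0.05, 0.05], [0.05, 0.05, 0.05, 0.05]]
targets = [[1, 1, 0, 0], [0, 0, 0, 0]]

# A batch of one patch is the same as the usual per-patch Dice loss.
per_patch = [batch_dice_loss([p], [t]) for p, t in zip(preds, targets)]
print(per_patch, sum(per_patch) / len(per_patch))
print(batch_dice_loss(preds, targets))
```

For this example the per-patch losses are about 0.077 and 1.0 (mean about 0.54), while the batch-level loss is about 0.12, so the well-predicted background patch no longer dominates the gradient signal.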

Hi,

I would be glad if you could share how you solved this, as I am currently facing the same challenge. Thanks