Imbalance in 3D binary segmentation

Hello! (I apologize if I do something wrong here, I am not used to asking questions on forums.)

I am working on a model to segment brain tumors from 3D brain MRIs, here doing binary segmentation.
I created a U-Net model that takes the 3D volume as a whole. I do a bit of preprocessing beforehand so that the data is approximately of size 1*150*220*220, the 1 because I only have one channel.
The preprocessing trims black borders, giving images of variable size, so I use a fully convolutional network with batches of one image.

The problem is that the tumors represent a very small portion of the voxels: around 1/300, creating a great imbalance.
I tried to use nn.BCEWithLogitsLoss as my loss function, since it allows weighting the positive class, but it does not seem to be working correctly: the model only gives me negative answers (no voxel is a tumor).

weights = torch.tensor([300.])
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=weights)

I tried changing the 300. to different values (0.0001 or 50000), hoping to see a result, or at least a variation, but nothing ever changed.

Am I using the function wrong? Is there another loss function that could do what I want?

Once again I apologize if this seems confused or easy. I have little experience with PyTorch, so I do everything based on what I can find online.
Thank you very much for your time and consideration.

Hi Owen!

This is indeed the correct approach to reweight your underrepresented
positive class (assuming that your tumor voxels are labelled as your
“positive” voxels).

The next step would be step-by-step debugging:

Read (or at least skim) through your code looking for typos or other bugs.

Make sure that you are not using a sigmoid() as a final “activation” in
your network. (You want to use the output of your final out_channels = 1
Conv3d layer as the predictions that are fed directly into your
BCEWithLogitsLoss loss criterion.)

Check how you are converting the output of your network into “hard,”
yes-no predictions for your voxels.

Look at your input images, output (prediction) images, and your label (mask)
images to make sure everything looks okay.

Run a single forward pass and check the output. Check the value of the
loss function with and without pos_weight.

Run a single backward pass. Do you get (reasonable) gradients?

Run a single optimizer step. Do your weights change?

Run another forward pass on the same input image. Does the output of
your model change?

If all of this seems to be working, see if you can overfit a small number
of input images by running many training epochs over and over again on
the same input images. Does the loss go down (even if slowly)? Does
the “probability” of a predicted positive voxel go up, even if slowly, and
even if it doesn’t cross the threshold you’ve chosen to convert the “soft,”
probabilistic prediction into a “hard,” yes-no prediction?
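The checks above can be sketched roughly as follows. This is only a minimal illustration, not the original poster's code: the two-layer model, the synthetic volume and mask, and all variable names are assumptions made for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the real U-Net. The last layer is a Conv3d with
# out_channels=1 and no sigmoid, so the model outputs raw logits.
model = nn.Sequential(
    nn.Conv3d(1, 4, kernel_size=3, padding=1),
    nn.LeakyReLU(),
    nn.Conv3d(4, 1, kernel_size=3, padding=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Synthetic volume and a hard 0/1 mask with a small positive region.
volume = torch.rand(1, 1, 8, 16, 16)
mask = torch.zeros(1, 1, 8, 16, 16)
mask[..., 2:4, 4:8, 4:8] = 1.0

# 1. Single forward pass; compare the loss with and without pos_weight.
logits = model(volume)
loss_plain = nn.BCEWithLogitsLoss()(logits, mask)
loss_weighted = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([300.0]))(logits, mask)
print(f"loss without pos_weight:  {loss_plain.item():.4f}")
print(f"loss with pos_weight=300: {loss_weighted.item():.4f}")

# 2. Hard predictions. Thresholding the logits at 0 corresponds to
#    thresholding sigmoid(logits) at 0.5 (up to floating-point rounding).
hard = logits > 0
print(f"predicted positive voxels: {int(hard.sum())}")

# 3. Single backward pass: inspect the gradient norm of every parameter.
optimizer.zero_grad()
loss_weighted.backward()
grad_norms = {name: p.grad.norm().item() for name, p in model.named_parameters()}
print(grad_norms)

# 4. Single optimizer step: the weights should change.
before = [p.detach().clone() for p in model.parameters()]
optimizer.step()
weights_changed = any(
    not torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)

# 5. Second forward pass on the same input: the output should change.
with torch.no_grad():
    logits_after = model(volume)
output_changed = not torch.equal(logits.detach(), logits_after)
print(f"weights changed: {weights_changed}, output changed: {output_changed}")
```

If the two losses in step 1 are nearly identical even though the mask is supposed to contain positive voxels, that is a hint that the mask values themselves deserve a closer look.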

And so on …

Good luck!

K. Frank

Hello and thank you so much for helping !

I have done step-by-step debugging, analyzing the output of my model before any sigmoid function.

On a small set (5) of images, I saw that the predictions of a default (untrained) model have an average around 0.0. Training the model for a few epochs (10 at most) on those images shifts the range down, leaving no value above 0 (so no voxel is counted as positive). The range of the outputs has been, for example, [-100, -1]. But the iterations do change the output of my network, so there is indeed some training happening.

Switching pos_weight between 300 and 1 changes the loss on the examples by around 10^-4 on a default model (from 0.6872... without the pos_weight to 0.6874... with it).

I apologize, I am not sure how to interpret those results, or whether they should be considered as “working” or not. There is indeed some training, but it goes in the wrong direction :confused:

Hi Owen!

Such values for BCEWithLogitsLoss in the context you describe look a
little odd to me.

Could you tell me in words what the last three operations your network
performs on the values it outputs as its predictions are? Please also then
post the last three layers of your network.

Please post the line where you instantiate your optimizer and the
section of your code from predictions = model (input) through
optimizer.step().

Print out the shape, type, min, and max of one of your 3d input images, the
shape, type, min, and max of the output your network predicts for it, together
with the shape, type, min, and max of its corresponding ground-truth mask.
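A small helper along these lines could print those statistics. The function name `describe` and the synthetic tensors are assumptions for illustration only; in practice you would pass in your real input, prediction, and mask.

```python
import torch

def describe(name: str, t: torch.Tensor) -> str:
    """Return a one-line summary: shape, dtype, device, min and max."""
    return (
        f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}, device={t.device}, "
        f"range=[{t.min().item()}; {t.max().item()}]"
    )

torch.manual_seed(0)

# Synthetic tensors standing in for a real image, prediction, and mask.
image = torch.rand(1, 1, 8, 16, 16)
logits = torch.randn(1, 1, 8, 16, 16)
mask = (torch.rand(1, 1, 8, 16, 16) > 0.99).float()

for name, t in [("data", image), ("prediction", logits), ("ground truth", mask)]:
    print(describe(name, t))

# For hard labels, the set of distinct mask values should be a subset of {0, 1}.
print("distinct mask values:", torch.unique(mask).tolist())
```

Listing the distinct mask values with `torch.unique` is a cheap extra check: a hard binary mask should contain nothing other than 0 and 1.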

Best.

K. Frank

Hello K. Frank !

The last three operations are a LeakyReLU and two 3D convolutions with a kernel of size 5,
like so:

nn.LeakyReLU(),
nn.Conv3d(in_channels=16, out_channels=16, kernel_size=5, padding=2),
nn.Conv3d(in_channels=16, out_channels=1, kernel_size=5, padding=2),

as for the code itself

import torch

lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
weights = torch.tensor([300.]).to(device)
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=weights, reduction='mean')

def epoch_train_all(model: Baseline, optim, loss_fn, device, batch_size: int = 5, path: str = "./data/train/"):
    optim.zero_grad()
    avg_loss = 0.0
    nb_files = 0
    model.to(device)
    for id, dm in get_files_it(path):
        nb_files += 1
        print(f"\r\tTRAIN file nb{nb_files:03d}, id:{id}", end=" ")
        # transferring data from cache to device
        data = dm.data.to(device)


        pred = model(data)
        del data
        mask = dm.mask.to(device)
        loss = loss_fn(pred, mask)
        # once the prediction and mask have been used, delete them to free some space on the device
        del pred
        del mask
        loss.backward()  # gradients accumulate across files until optim.step()
        avg_loss += float(loss)
        del loss
        if nb_files % batch_size == 0:
            print("optim !", end='')
            optim.step()
            optim.zero_grad()
            return  # debugging only: stop after the first batch of files


    # step on the gradients left over from an incomplete last batch
    optim.step()
    optim.zero_grad()
    avg_loss /= nb_files
    print(f"avg. train loss: {avg_loss}")
    return avg_loss

I have isolated in the middle the part that you asked for.
A few clarifications:

  • I use gradient accumulation, so my batch size controls how often my optimizer steps (the backward pass still runs once per file)
  • there is a return after the 5 items of my batch so that I fit on only the first 5 elements of my data (I have disabled the shuffling)
  • get_files_it(path): is an iterator I have made which returns my processed data in a class I have created
  • I delete the elements once they are not used anymore to save some memory
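As a side note on the accumulation itself, here is a minimal sketch, assuming equally sized synthetic volumes and a hypothetical one-layer model, showing that dividing each per-file loss by the number of accumulated files reproduces the gradient of a true batch-mean loss. The code above does not do this division, so its accumulated gradient is the sum rather than the mean over the files. With variable-size volumes the files cannot be concatenated into one batch, which is why accumulation is used in the first place.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Conv3d(1, 1, kernel_size=3, padding=1)  # hypothetical tiny model
loss_fn = nn.BCEWithLogitsLoss()
accum_steps = 4

# Synthetic, equally sized volumes and hard 0/1 masks.
volumes = [torch.rand(1, 1, 4, 8, 8) for _ in range(accum_steps)]
masks = [(torch.rand(1, 1, 4, 8, 8) > 0.9).float() for _ in range(accum_steps)]

# Gradient accumulation: one file per forward/backward pass.
model.zero_grad()
for volume, mask in zip(volumes, masks):
    loss = loss_fn(model(volume), mask) / accum_steps  # scale to get a mean
    loss.backward()                                    # adds into .grad
accum_grad = model.weight.grad.clone()

# Reference: the same files processed as one real batch.
model.zero_grad()
batch_loss = loss_fn(model(torch.cat(volumes)), torch.cat(masks))
batch_loss.backward()
batch_grad = model.weight.grad.clone()

grads_match = torch.allclose(accum_grad, batch_grad, atol=1e-6)
print(f"accumulated gradient matches batch gradient: {grads_match}")
```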

As for the statistics:
data: input data
prediction: the output value of my network
pred w/ sig: the previous prediction passed through a sigmoid (not in the network, because I am working with BCEWithLogitsLoss)
ground truth: the expected mask

1st epoch:

data:        	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[0.0; 1.0]
prediction: 	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[-0.016314834356307983; 0.0024049831554293633]
pred w/ sig:	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[0.49592143297195435; 0.5006012320518494]
ground truth:	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[0.0; 0.0021929824724793434]

10th epoch

data:       	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[0.0; 1.0]
prediction: 	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[-269.3184814453125; -1.7151740789413452]
pred w/ sig:	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[0.0; 0.15249381959438324]
ground truth:	shape=torch.Size([1, 1, 248, 88, 176]), type: torch.cuda.FloatTensor, range:[0.0; 0.0021929824724793434]

Thank you very much for your time and help, I appreciate it :relaxed:

Hi Owen!

Thank you for the details.

Assuming that “ground truth” is the mask that you pass into loss_fn(),
this is likely your problem. BCEWithLogitsLoss expects its target (your
mask) to be the probability of the corresponding item being in the “positive”
class. Based on the range you posted, some of your mask values are
0.0, which means definitely in the “negative” class (that is, 0% probability
of being in the “positive” class). However, your largest mask value is
0.00219, which means that the corresponding item has only a 0.219%
probability of being in the “positive” class, so it is almost certainly
in the “negative” class (with a probability of 99.78%).

Therefore your ground-truth mask is saying that all of your voxels are
either certainly “negative” or almost certainly “negative.” So you are
(successfully) training your network to always predict “negative” voxels
(regardless of the value of pos_weight, because, according to mask,
you don’t have any positive voxels).
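To make this concrete, here is a small sketch under stated assumptions: a synthetic mask with 1 positive voxel in 300 (the ratio mentioned earlier in the thread), a mis-scaled copy whose maximum is 0.0022, and all-zero logits standing in for an untrained network. The helper names `loss_at_zero` and `optimal_probability` are made up for this example.

```python
import math
import torch
from torch import nn

# Hypothetical hard mask: 1 voxel in 300 is positive (30 of 9000 voxels).
hard_mask = torch.zeros(1, 1, 10, 30, 30)
hard_mask.view(-1)[::300] = 1.0

# Mis-scaled mask: positives are about 0.0022 instead of 1.0.
scaled_mask = hard_mask * 0.0022

# An untrained network outputs logits near 0; use exactly 0 for simplicity.
logits = torch.zeros_like(hard_mask)

def loss_at_zero(mask: torch.Tensor, pos_weight: float) -> float:
    """BCEWithLogitsLoss at all-zero logits for a given mask and pos_weight."""
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
    return loss_fn(logits, mask).item()

for label, mask in [("hard mask", hard_mask), ("mis-scaled mask", scaled_mask)]:
    plain = loss_at_zero(mask, 1.0)
    weighted = loss_at_zero(mask, 300.0)
    print(f"{label}: pos_weight=1 -> {plain:.4f}, pos_weight=300 -> {weighted:.4f}")

def optimal_probability(target: float, pos_weight: float) -> float:
    """Probability p minimizing -w*y*log(p) - (1-y)*log(1-p) for one voxel."""
    return pos_weight * target / (pos_weight * target + 1.0 - target)

# Even on the "positive" voxels of the mis-scaled mask, pos_weight=300 leaves
# the loss-minimizing probability below 0.5, i.e. a negative logit.
p_star = optimal_probability(0.0022, 300.0)
print(f"optimal probability for target 0.0022, pos_weight 300: {p_star:.3f}")
print(f"ln 2 = {math.log(2):.4f}")
```

With the hard mask, pos_weight=300 roughly doubles the initial loss, while with the mis-scaled mask the loss stays close to ln 2 either way, which is the same pattern as the nearly identical loss values reported earlier in the thread.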

I don’t know where your mask comes from, but could it be that 0.0 means
“negative” while any non-zero value means “positive?” You would want
“positive” voxels to be labelled with a value equal to, or at least reasonably
close to 1.0.

Regardless of the scale of the values, is mask supposed to contain hard
“negative”-or-“positive” labels, or are the values to be understood as soft,
probabilistic labels?

Best.

K. Frank

Hello KFrank,

I feel so stupid for not seeing this.
I was applying the wrong normalization to my masks (they come with 0 as the negative value and 2 as the positive value). I was using the same normalization I apply to my data, giving stupidly low values. I am now trying with the corrected data.

The masks are booleans: they should be yes-or-no values per voxel.
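A minimal sketch of the corrected mask handling, assuming raw masks that use 0 for negative and 2 for positive as described above. The helper name `binarize_mask` and the tiny example tensor are made up for illustration, and deriving pos_weight as the negative-to-positive ratio is one common choice rather than the only option.

```python
import torch

def binarize_mask(raw_mask: torch.Tensor) -> torch.Tensor:
    """Map any non-zero label to 1.0 and zero to 0.0, as float32."""
    return (raw_mask > 0).to(torch.float32)

# Synthetic raw mask using the 0 / 2 convention.
raw_mask = torch.tensor([[0, 2, 0], [2, 2, 0]])
mask = binarize_mask(raw_mask)
print(mask)

# One common way to choose pos_weight: number of negatives per positive.
num_pos = mask.sum()
num_neg = mask.numel() - num_pos
pos_weight = (num_neg / num_pos).reshape(1)
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
print(f"pos_weight: {pos_weight.item()}")
```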

Thank you very much,
Owen