How to apply a weighted BCE loss to an imbalanced dataset? What will the weight tensor contain?

Hi,

There have been previous discussions on weighted BCELoss here, but none of them gives a clear answer on how to actually apply the weight tensor and what it should contain.

I’m doing binary segmentation where each output is either foreground or background (1 and 0). My dataset is highly imbalanced: there are about 95 times more background pixels than foreground pixels. So I want to down-weight the background by multiplying its loss contribution by a small number.

I see that BCELoss has a weight parameter (https://pytorch.org/docs/stable/nn.html), which according to the docs should be of size nbatch.

So if I have a batch of size 1, would the weight tensor be of size 1 as well? Won’t that just be a single float value?

I’ve read the discussion here: Binary cross entropy weights, but that does not answer what the weight tensor would look like for a batch size of 1.

Sorry if I’m rambling, but I can’t seem to find anywhere how to use the weight parameter properly. If my output and label tensors are of size [1, 1, 60, 40, 40], how would I construct the weight tensor to down-weight the background class (0)?

Hi phdproblems!

Let me add some hypothetical context to your question to make it
more concrete:

As you say, you have a batch size of one. Each batch tensor is
of torch.Size([1, 1, 60, 40, 40]), so each sample is of
torch.Size([1, 60, 40, 40]).

Let me imagine that each sample represents, say, a one-channel
(e.g., black and white), three-dimensional image of 60 z-slices
(or time-slices), each of x-y dimension of 40x40. The output
of your network is a tensor for which each element is the predicted
probability (between 0 and 1) that the corresponding voxel (3-d
pixel) in your 3-d image is a foreground voxel. I will call
this your predict tensor.

Your target tensor is of the same shape, and is the known
training data with values of 0 or 1 indicating background and
foreground voxels, respectively.

(That is, your network performs a binary classification of your
voxels. As far as the loss is concerned, you don’t care that
the voxels happen to be arranged as a (one-channel) 60x40x40
3-d image.)

To weight each voxel’s contribution separately, you want a weight
tensor that is the same shape as your predict and target tensors,
i.e., torch.Size([1, 1, 60, 40, 40]). (So its first dimension is
your batch size of one.)

The following script (for pytorch 0.3.0) illustrates this:

import torch
print (torch.__version__)

torch.manual_seed (2019)

predict = torch.rand ((1, 1, 60, 40, 40))
target = torch.bernoulli (predict)
weight =  torch.rand ((1, 1, 60, 40, 40))

weight_rebal = torch.ones_like (target) / 95.0  +  (1.0 - 1.0 / 95.0) * target

predict_f = predict.clone()
target_f = target.clone()
weight_f = weight.clone()
predict_f.resize_((96000, 1))
target_f.resize_((96000, 1))
weight_f.resize_((96000, 1))

predict = torch.autograd.Variable (predict)
target = torch.autograd.Variable (target)

predict_f = torch.autograd.Variable (predict_f)
target_f = torch.autograd.Variable (target_f)

loss = torch.nn.functional.binary_cross_entropy (predict, target)
lossw = torch.nn.functional.binary_cross_entropy (predict, target, weight)
loss_rebal = torch.nn.functional.binary_cross_entropy (predict, target, weight_rebal)

loss_f = torch.nn.functional.binary_cross_entropy (predict_f, target_f)
loss_fw = torch.nn.functional.binary_cross_entropy (predict_f, target_f, weight_f)

print ('predict.shape =', predict.shape)
print ('loss =', loss)
print ('lossw =', lossw)
print ('loss_rebal =', loss_rebal)
print ('predict_f.shape =', predict_f.shape)
print ('loss_f =', loss_f)
print ('loss_fw =', loss_fw)

The output is:

0.3.0b0+591e73e
predict.shape = torch.Size([1, 1, 60, 40, 40])
loss = Variable containing:
 0.5002
[torch.FloatTensor of size 1]

lossw = Variable containing:
 0.2503
[torch.FloatTensor of size 1]

loss_rebal = Variable containing:
 0.2530
[torch.FloatTensor of size 1]

predict_f.shape = torch.Size([96000, 1])
loss_f = Variable containing:
 0.5002
[torch.FloatTensor of size 1]

loss_fw = Variable containing:
 0.2503
[torch.FloatTensor of size 1]

weight is a tensor of random weights (one for each voxel).

weight_rebal is a tensor of foreground-background weights
(again, one for each voxel), computed from your training-data
target, where background voxels are given a weight factor of
1/95 (to account for their greater frequency of occurrence).
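
For more recent PyTorch versions (where Variable wrappers are no longer needed), the same rebalancing weight could be sketched roughly as follows. This is only an illustration, not a rerun of the script above: the 5% foreground rate, the clamp on predict, and the names shape and bg_weight are assumptions made for the example. It also checks the weighted loss against a hand-written version of the per-voxel formula.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(2019)

shape = (1, 1, 60, 40, 40)

# Stand-ins for the network output and the labels (assumed data, not real images).
# The clamp keeps probabilities away from exactly 0 or 1 so log() stays finite.
predict = torch.rand(shape).clamp(1e-6, 1 - 1e-6)
target = (torch.rand(shape) < 0.05).float()  # roughly 5% foreground voxels

# Weight 1 for foreground voxels (target == 1), 1/95 for background voxels.
bg_weight = 1.0 / 95.0
weight_rebal = torch.where(
    target == 1,
    torch.ones_like(target),
    torch.full_like(target, bg_weight),
)

# The weight tensor has the same shape as predict and target.
loss_rebal = F.binary_cross_entropy(predict, target, weight=weight_rebal)

# Hand-written check: weighted mean of the per-voxel BCE terms.
per_voxel = -(target * torch.log(predict) + (1 - target) * torch.log(1 - predict))
manual = (weight_rebal * per_voxel).mean()

print(loss_rebal.item(), manual.item())
```

The torch.where form builds the same tensor as the arithmetic expression used for weight_rebal in the script; which one to use is a matter of taste.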

Last, the _f (for flattened) tensors and losses are just to show
that the shape doesn’t affect the per-voxel loss computation.
These can be understood, if you will, as consisting of a batch
of 96,000 samples (batch size = 96,000) of single floating-point
prediction values and single 0 or 1 class labels.

(As a side note, CrossEntropyLoss permits you to specify a
per-class weight, e.g., one weight for foreground and another
weight for background voxels, rather than having to pass in a
tensor of per-voxel weights. You could recast your problem
as a two-class (foreground / background) multiclass classification
problem and use CrossEntropyLoss, but I wouldn’t recommend it.)
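
For completeness, the per-class-weight route might look roughly like the sketch below. It assumes a hypothetical network with two output channels (background, foreground) and a target of integer class indices without the channel dimension, so a target of shape [1, 1, 60, 40, 40] would first need something like target.squeeze(1).long(). The small tensor sizes are chosen only to keep the example quick. Note that with the default mean reduction, cross_entropy divides by the sum of the weights of the target classes rather than by the number of voxels, so the numerical value will generally differ from the weighted BCE mean.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed shapes: (batch, classes, z, y, x) logits and (batch, z, y, x) class indices.
logits = torch.randn(1, 2, 6, 4, 4)
target = torch.randint(0, 2, (1, 6, 4, 4))

# One weight per class: index 0 = background, index 1 = foreground.
class_weight = torch.tensor([1.0 / 95.0, 1.0])

loss = F.cross_entropy(logits, target, weight=class_weight)

# Hand-written check of the weighted-mean reduction.
per_voxel = F.cross_entropy(logits, target, reduction="none")
w = class_weight[target]                      # per-voxel weight looked up by class
manual = (w * per_voxel).sum() / w.sum()

print(loss.item(), manual.item())
```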

Good luck.

K. Frank

Thank you so much. I think this might be the first detailed response to this question here, and it finally cleared up my confusion about the weight tensor’s size and shape. I really appreciate the time you took to write this detailed answer :heart:

Can you please explain why you would not recommend using CrossEntropyLoss? I thought BCE is just CE for a binary problem. Is it not advised to use CE for binary problems?

Thanks again.

Hi, another follow-up question to this.

What you explained in your answer is exactly what I need, i.e., calculating the weight tensor for each instance based on the target tensor. But from what I have read, PyTorch does not support this; it only supports the same weight for all instances in a batch, which has to be provided when the loss is declared/initialized.

Is this the case or can I provide a different weight for each instance?

Thank you

Hello phdproblems!

BCE takes a single number per sample for its prediction – the
probability of the sample being in class “1”.

Multiclass CE takes, for N classes, N numbers for its prediction.
Conceptually, these are the probabilities of the sample being in
each of the N classes. (PyTorch’s CrossEntropyLoss has a softmax
implicitly built in, so it takes logits, rather than probabilities.)

So if you use multiclass CE for a binary (two-class) problem,
you will now have to pass two numbers to it instead of one.
(If these are official probabilities, they will be redundant, in
that they will sum to 1.) So your network will have to have
two (redundant) outputs instead of one. Hardly a big deal,
but it would presumably make your network ever so slightly
less efficient.

If you do everything correctly, you will be doing exactly the
same problem and will get exactly the same result (up to
some round-off error), but it just seems cleaner to me to
use BCE for the binary problem.
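
The equivalence can be checked numerically with a small sketch like the one below. It relies on the fact that a softmax over two logits equals a sigmoid of their difference. The tensor names and the use of binary_cross_entropy_with_logits (so that the comparison is made on logits rather than probabilities) are choices made for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n = 8
two_logits = torch.randn(n, 2)       # assumed two-output network: [background, foreground]
labels = torch.randint(0, 2, (n,))   # class indices, 0 or 1

# Two-class cross entropy on the pair of logits.
ce = F.cross_entropy(two_logits, labels)

# Single-output version: softmax([z0, z1])[1] == sigmoid(z1 - z0).
one_logit = two_logits[:, 1] - two_logits[:, 0]
bce = F.binary_cross_entropy_with_logits(one_logit, labels.float())

print(ce.item(), bce.item())  # equal up to round-off
```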

Best.

K. Frank

Hi phdproblems!

Well, the documentation for the weight argument in BCELoss
and binary_cross_entropy is, to put it nicely, somewhat abbreviated.

I purposely used binary_cross_entropy in my example,
because you can pass in a batch of weights (together with
your predict and target) every time the loss is called.

(As you note, with BCELoss you pass in the weight only at
the beginning when you instantiate the BCELoss class, so
you can’t give it different weights every time you call it with a
new predict and target.)

Also in my example I showed that passing in per-voxel weights
to binary_cross_entropy does indeed work, even if the
documentation doesn’t make this clear, and showed that this
gives the same result as the “flattened” version (predict_f,
target_f, weight_f) where a batch of weights – one weight
for each sample in the batch – is passed in, consistent with
the simplest interpretation of the documentation.
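
If you would rather keep the BCELoss class, one possible workaround is sketched below. It assumes a recent PyTorch release in which BCELoss accepts reduction="none" (this is not the 0.3.0 API used in my script), and the helper name weighted_bce is made up for the example. The idea is to get the unreduced per-voxel losses, multiply them by a weight tensor computed from the current target, and take the mean yourself. The sketch also checks that this matches binary_cross_entropy called with the same weight.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Instantiated once, with no weight; returns one loss value per voxel.
criterion = torch.nn.BCELoss(reduction="none")

def weighted_bce(predict, target, bg_weight=1.0 / 95.0):
    """Hypothetical helper: rebuild the weight from each batch's target."""
    weight = bg_weight + (1.0 - bg_weight) * target  # 1 for foreground, bg_weight for background
    return (weight * criterion(predict, target)).mean()

# Stand-in data: random logits passed through a sigmoid, sparse random labels.
logits = torch.randn(1, 1, 60, 40, 40, requires_grad=True)
predict = torch.sigmoid(logits)
target = (torch.rand(predict.shape) < 0.05).float()

loss_module = weighted_bce(predict, target)

# Same computation through the functional interface with a per-call weight.
bg = 1.0 / 95.0
loss_functional = F.binary_cross_entropy(predict, target, weight=bg + (1.0 - bg) * target)

loss_module.backward()  # gradients flow back to logits as usual
print(loss_module.item(), loss_functional.item())
```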

Good luck.

K. Frank

You’re right, sorry I missed that you were using binary_cross_entropy() and not BCELoss() (which I am using).

Thanks for clarifying. I’ll just use binary_cross_entropy() :slight_smile:

Thanks for your help