I am trying to find an answer to this problem, but every solution I found only discusses the multiclass case, so I thought I would ask.
I have ground-truth images of size n×n with pixel values of 0 or 1. The problem is that the ratio of 1-pixels to 0-pixels is about 0.00005 in every ground-truth image, so the binary ground truth is highly imbalanced. As expected, the model predicts 0 for almost every pixel. It reaches a small loss value but never predicts any pixel as 1. Is there a way to make BCEWithLogitsLoss weight the loss to account for this imbalance? If so, how do I do it?
The documentation for BCEWithLogitsLoss talks about pos_weight for the positive and negative classes, but I am not sure I understand what it means.
Great question. The docs aren’t very clear with the example they give. However, they do at least say:
For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100=3. The loss would act as if the dataset contains 3×100=300 positive examples.
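From this, we can deduce that if we set pos_weight=3 for each target 1 in the above example, the two classes would contribute equally to the loss. Equivalently, we could weight each target 0 by 1/3 (about 0.3333). However, pos_weight itself only scales the terms where the target is 1, so weighting the 0s would have to go through the separate weight argument.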
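Both statements can be checked numerically. The sketch below uses random logits and reduction="sum". With the default "mean", the versions would also differ by the number of terms they average over. It compares three things: pos_weight=3, literally repeating each positive three times, and down-weighting the 0s by 1/3 through the per-element `weight` argument.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 100 positive and 300 negative examples of a single class, as in the docs example
target = torch.cat([torch.ones(100), torch.zeros(300)])
logits = torch.randn(400)  # stand-in model outputs (raw logits)

# (a) pos_weight = 300 / 100 = 3; it only scales the terms where target == 1
loss_pos = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]), reduction="sum")(logits, target)

# (b) no weighting, but every positive example literally repeated 3 times
rep_logits = torch.cat([logits[:100].repeat(3), logits[100:]])
rep_target = torch.cat([target[:100].repeat(3), target[100:]])
loss_dup = nn.BCEWithLogitsLoss(reduction="sum")(rep_logits, rep_target)

# (c) down-weight the 0s by 1/3 through the per-element `weight` argument instead
w = torch.where(target == 1, torch.tensor(1.0), torch.tensor(1 / 3))
loss_neg = nn.BCEWithLogitsLoss(weight=w, reduction="sum")(logits, target)

print(torch.allclose(loss_pos, loss_dup, rtol=1e-4))      # (a) matches (b)
print(torch.allclose(loss_pos, 3 * loss_neg, rtol=1e-4))  # (c) is (a) scaled by 1/3
```

Options (a) and (c) give the same 1:1 balance between the classes. They differ only by an overall factor of 3, which in practice acts like a change in learning rate.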
So here is a function you can try, and please let me know how it goes:
```python
import torch

def get_pos_weight(target_batch):
    x = target_batch.sum(dim=0)  # number of 1s at each position, summed over the batch
    out = target_batch * (target_batch.size(0) / x - 1)  # ratio of 0s to 1s wherever the target is 1
    out[out == 0] = 1  # set all 0s to a weight of 1
    out[torch.isnan(out)] = 1  # handles positions where the batch has only 0s (0 * inf = nan)
    return out
```
Simply feed it your target batch, and it should* give you the pos_weight tensor (the weights for the positive class).
Here is an example:
```python
import torch

torch.manual_seed(0)
target = torch.round(torch.rand(10, 5))  # batch_size = 10 and classes = 5
off_balance_targs = torch.cat([torch.ones(10, 1), torch.zeros(10, 4)], dim=1)
# ^^^ 10 extra rows: all 1s in the first column, all 0s in the rest
# that will make our batch_size = 20 with targets that are both randomized and intentionally off balance
target = torch.cat([target, off_balance_targs], dim=0)
# get pos_weight
pos_weight = get_pos_weight(target)
print(pos_weight)
output = torch.randn((20, 5))  # a model prediction (logit); randn so it's raw logits (negative and positive)
# calculate the loss
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(output, target)
print(loss)
```
*I have not actually tested this on an unbalanced dataset, so please let me know if it works out.
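To see concretely what get_pos_weight returns, here is a tiny hand-checkable batch with 4 samples and 3 columns. The function is repeated so the snippet runs on its own:

```python
import torch

def get_pos_weight(target_batch):
    # Same logic as the function above
    x = target_batch.sum(dim=0)  # count of 1s per column over the batch
    out = target_batch * (target_batch.size(0) / x - 1)  # ratio of 0s to 1s where target is 1
    out[out == 0] = 1  # positions holding a 0 get weight 1
    out[torch.isnan(out)] = 1  # columns with no 1s: 0 * inf = nan -> weight 1
    return out

# Column 0 has one 1, column 1 has two 1s, column 2 has none
target = torch.tensor([[1., 1., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]])
print(get_pos_weight(target))
```

Column 0 has one 1 and three 0s, so its 1 gets weight 3. Column 1 is balanced (two of each), so its 1s get weight 1. Column 2 has no 1s at all, so the NaNs from 0 * inf fall back to 1, as does every position holding a 0.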
Thank you for the reply. I will test it and update the post once I see the training results.
That being said, can you clarify this:
Let’s assume target_batch has shape [B,C,H,W].
x=target_batch.sum(dim=0) computes the sum along the batch dimension, giving [C1,H1,W1] + [C2,H2,W2] + … + [Cn,Hn,Wn] over the whole batch, i.e. a tensor of shape [C,H,W].
But how does target_batch*(target_batch.size(0)/x-1) give us the ratio of 1s to 0s?
Can you elaborate on this?
One way you can go is to use focal loss.
Here is a simple implementation of it:
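```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class binaryFocalCrossentropy(nn.Module):
    def __init__(self, from_sigmoid=False, apply_class_balancing=True, alpha=0.25, gamma=2.0, **kwargs):
        super(binaryFocalCrossentropy, self).__init__()
        self._alpha = alpha
        self._gamma = gamma
        self._from_sigmoid = from_sigmoid
        self._apply_class_balancing = apply_class_balancing

    def forward(self, yhat, y):
        # Per-element BCE, applying the sigmoid exactly once
        if self._from_sigmoid:
            BCE_loss = F.binary_cross_entropy(yhat, y, reduction='none')  # yhat is already a probability
        else:
            BCE_loss = F.binary_cross_entropy_with_logits(yhat, y, reduction='none')  # yhat is a raw logit
        pt = torch.exp(-BCE_loss)  # probability assigned to the true class
        loss = (1 - pt) ** self._gamma * BCE_loss  # down-weight easy examples
        if self._apply_class_balancing:
            # alpha weights the 1s, (1 - alpha) weights the 0s
            loss = loss * (y * self._alpha + (1 - y) * (1 - self._alpha))
        return loss.mean()
```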
alpha (set to 0.25 here) balances the classes: it is the weight on the loss for the 1s, while the 0s get 1 - alpha. gamma distinguishes between easy and hard examples by down-weighting the easy ones, so that training focuses more on the hard ones.
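A quick sanity check for a focal loss of this form is that it should reduce to plain binary cross-entropy when gamma=0 and class balancing is off. When gamma>0, each term can only shrink. The sketch below writes the standard formulation (alpha-weighted (1 - pt)^gamma times BCE) as a standalone function. binary_focal_loss is a hypothetical name, and the sigmoid is applied once through the logits-based BCE.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, y, alpha=0.25, gamma=2.0, apply_class_balancing=True):
    # Hypothetical helper: per-element BCE computed once from raw logits
    bce = F.binary_cross_entropy_with_logits(logits, y, reduction="none")
    pt = torch.exp(-bce)  # probability assigned to the true class
    loss = (1 - pt) ** gamma * bce  # down-weight easy examples (pt close to 1)
    if apply_class_balancing:
        loss = loss * (y * alpha + (1 - y) * (1 - alpha))  # alpha for 1s, 1 - alpha for 0s
    return loss.mean()

torch.manual_seed(0)
logits = torch.randn(2, 1, 8, 8)
y = (torch.rand(2, 1, 8, 8) > 0.9).float()

# gamma = 0 and no class balancing: should match plain BCE
plain = F.binary_cross_entropy_with_logits(logits, y)
print(torch.allclose(binary_focal_loss(logits, y, gamma=0.0, apply_class_balancing=False), plain))

# gamma > 0 only shrinks each term, so the focal loss cannot exceed plain BCE
print(binary_focal_loss(logits, y, apply_class_balancing=False) <= plain)
```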
We start with the assumption that the targets are all 0s or 1s. Thus the elementwise count of 1s is given by summing over the batch dimension:
x=target_batch.sum(dim=0)
Conversely, the sum of 0s elementwise is given by:
y=(1-target_batch).sum(dim=0)
Also, batch_size = x+y, when applied elementwise.
But we want the ratio y/x, that is, the ratio of 0s to 1s. Since y = batch_size - x, we can substitute that in for y:
ratio = (batch_size - x)/x
And a little more algebra gives the final result:
ratio = batch_size/x - 1
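The algebra can be checked numerically on a random [B,C,H,W] target, with sizes chosen arbitrarily. The comparison is restricted to positions where x > 0, since both forms divide by zero elsewhere:

```python
import torch

torch.manual_seed(0)
B, C, H, W = 8, 1, 4, 4
target_batch = (torch.rand(B, C, H, W) > 0.7).float()  # random 0/1 targets

x = target_batch.sum(dim=0)        # count of 1s per (c, h, w) position
y = (1 - target_batch).sum(dim=0)  # count of 0s per (c, h, w) position

print(x.shape)                                  # one count per (c, h, w) position
print(torch.equal(x + y, torch.full_like(x, B)))  # batch_size = x + y elementwise

# Compare y / x with batch_size / x - 1 where x > 0
has_ones = x > 0
ratio_direct = y[has_ones] / x[has_ones]
ratio_formula = B / x[has_ones] - 1
print(torch.allclose(ratio_direct, ratio_formula))
```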
Multiplying this ratio by target_batch then sets every element that is 1 to the ratio of 0s to 1s, since we need the inverse frequency as the weight in order to balance out recall. Everywhere there is a 0, we set the weight to 1.
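Now, I wasn’t sure whether the PyTorch function needs those 0s set to 1 or kept at their inverse ratio. Going by the BCEWithLogitsLoss formula, pos_weight only multiplies the term for targets equal to 1. So the values at positions where the target is 0 should not change the loss, as long as they are finite. If you want to try the latter anyway, remove out[out==0]=1 and change the out = … line to out=torch.ones_like(target_batch)*(target_batch.size(0)/x-1).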
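This can be tested directly by building both variants of pos_weight on random data and comparing the losses. Any inf is replaced with 1 so every entry is finite. Since the positive term is multiplied by the target, the entries sitting at target == 0 should drop out:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
target = (torch.rand(16, 4) > 0.7).float()
logits = torch.randn(16, 4)

x = target.sum(dim=0)
ratio = target.size(0) / x - 1  # ratio of 0s to 1s per column (inf where a column has no 1s)
ratio[torch.isinf(ratio)] = 1   # keep every entry finite

# Variant 1: ratio where target is 1, weight 1 where target is 0
pw_ones = torch.where(target == 1, ratio.expand_as(target), torch.ones_like(target))
# Variant 2: ratio everywhere, including where target is 0
pw_ratio = ratio.expand_as(target).clone()

loss1 = nn.BCEWithLogitsLoss(pos_weight=pw_ones)(logits, target)
loss2 = nn.BCEWithLogitsLoss(pos_weight=pw_ratio)(logits, target)
print(torch.allclose(loss1, loss2))  # entries at target == 0 do not affect the loss
```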
Also, if you get any inf loss values, you might need to change the NaN-handling line to out[torch.isnan(out)|torch.isinf(out)]=1.
As for flattening: as far as I can tell from the docs, BCEWithLogitsLoss only requires the input and target to have the same shape, so flattening should not be necessary. Either way, the provided pos_weight function is general enough to take either shape (flattened or not).
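A [B,1,H,W] mask can therefore go into BCEWithLogitsLoss without flattening. For a single-channel mask like the one in the question, one simple option is a single scalar pos_weight equal to total 0s over total 1s. This is an assumption on my part, not the per-position function from this thread. The sketch below uses a synthetic mask and checks that the 4-D and flattened calls give the same loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, H, W = 2, 64, 64
logits = torch.randn(B, 1, H, W)  # raw model outputs for a 1-channel mask
target = torch.zeros(B, 1, H, W)
target[:, :, 10:12, 20:22] = 1    # a few foreground pixels, heavily outnumbered

# Assumed global weight: total 0s / total 1s, one scalar for the single class
num_pos = target.sum()
num_neg = target.numel() - num_pos
pos_weight = (num_neg / num_pos.clamp(min=1)).reshape(1)  # clamp guards a batch with no 1s

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss_4d = criterion(logits, target)                                   # no flattening
loss_flat = criterion(logits.reshape(B, -1), target.reshape(B, -1))  # flattened per sample
print(pos_weight, torch.allclose(loss_4d, loss_flat))
```

Here pos_weight is 8184 / 8 = 1023. With a 1-to-0 ratio around 0.00005 as in the question, the scalar would be on the order of 20,000, and a given batch may contain no 1s at all. Computing the counts once over the whole training set is therefore probably more stable than recomputing them per batch.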