Hi everyone,
I want to use nn.NLLLoss() to implement the Focal Loss function for a binary segmentation problem. I use the logits obtained from a network and binary targets (i.e., 0 for the -ve class and 1 for the +ve class). The following cases summarize my failed attempts (T_T)
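For reference, the focal loss is FL(p_t) = -(1 - p_t)^γ · log(p_t), where p_t is the predicted probability of the true class. Below is a minimal sketch of how it could sit on top of F.nll_loss. It assumes the network outputs one logit per pixel with shape (N, 1, H, W); the conversion to two log-probability channels via logsigmoid and the name binary_focal_loss are my own choices, not an official PyTorch API.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, target, gamma=2.0, weight=None):
    """Hypothetical helper: binary focal loss built on F.nll_loss.

    logits: (N, 1, H, W) raw scores for the positive class (assumed layout)
    target: (N, H, W) int64 tensor with values in {0, 1}
    weight: optional (2,) tensor of class weights (neg, pos)
    """
    # Expand one logit into two log-probabilities per pixel:
    # log p(neg) = logsigmoid(-x), log p(pos) = logsigmoid(x)
    log_prob = torch.cat([F.logsigmoid(-logits), F.logsigmoid(logits)], dim=1)  # (N, 2, H, W)
    prob = log_prob.exp()
    # Focal modulation: scale each class's log-probability by (1 - p)^gamma
    focal_log_prob = (1.0 - prob) ** gamma * log_prob
    # nll_loss picks the target-class entry at every pixel, negates it, and averages
    return F.nll_loss(focal_log_prob, target, weight=weight)

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)
target = (torch.rand(2, 4, 4) > 0.5).long()
print(binary_focal_loss(logits, target, gamma=2.0))
```

With gamma=0 the modulation factor is 1 and the result reduces to ordinary binary cross-entropy, which is a quick sanity check.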
# code for reproducibility
import torch
from torch import nn
log_prob = torch.rand((8, 1, 128, 128)).view(8*128*128, 1) # (N,C,d1,d2, ..., dK) --> (N*d1*...*dK,C); view() matches this only because C == 1
target = (torch.rand((8, 1, 128, 128)) > 0.5).long().view(-1) # (N,d1,d2,...,dK) --> (N*d1*...*dK,)
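As far as I know, nn.NLLLoss also accepts the K-dimensional layout directly (input (N, C, d1, ..., dK), target (N, d1, ..., dK)), so flattening is optional. If you do flatten with C > 1, the channel dim has to be moved last first; a plain view() would interleave channels and pixels. A small sketch with made-up sizes comparing both forms:

```python
import torch
from torch import nn

torch.manual_seed(0)
N, C, H, W = 2, 3, 4, 4                                   # illustrative sizes only
log_prob = torch.log_softmax(torch.randn(N, C, H, W), dim=1)  # (N, C, H, W)
target = torch.randint(0, C, (N, H, W))                   # (N, H, W), values in [0, C)

loss_fn = nn.NLLLoss()
# 1) K-dimensional form: no reshaping needed
loss_nd = loss_fn(log_prob, target)
# 2) Flattened form: permute channels to the last dim, then reshape
flat_log_prob = log_prob.permute(0, 2, 3, 1).reshape(-1, C)  # (N*H*W, C)
flat_target = target.reshape(-1)                              # (N*H*W,)
loss_flat = loss_fn(flat_log_prob, flat_target)
print(loss_nd, loss_flat)
```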
Case 1: No weight OR weight for the positive class only.
I have no idea what caused this error since the target values are within [0, 1].
>> nn.NLLLoss()(log_prob, target)
OR
>> weight = torch.tensor([44.0])
>> nn.NLLLoss(weight)(log_prob, target)
"IndexError: Target 1 is out of bounds."
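A likely explanation for Case 1: NLLLoss reads the number of classes C from dim 1 of its input. With an (M, 1) input, C == 1, so the only legal target index is 0 and every 1 is out of bounds; likewise, weight=[44.0] is accepted only because it has exactly C == 1 entry. The sketch below assumes the network is changed to emit two channels (neg, pos), which is one way to make targets 0 and 1 valid; the small size M is arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
M = 16                                           # flattened pixel count, kept small for the demo
target = (torch.rand(M) > 0.5).long()            # (M,), values in {0, 1}

# NLLLoss reads the class count C from dim 1 of its input.
one_col = torch.rand(M, 1)                       # the Case 1 shape: C == 1
num_classes = one_col.shape[1]                   # 1 -> only target index 0 is legal
print(f"C = {num_classes}, legal targets: 0..{num_classes - 1}, max target: {int(target.max())}")

# Assumption: the network now emits two channels (neg, pos) instead of one.
log_prob = torch.log_softmax(torch.randn(M, 2), dim=1)  # (M, 2)
loss = nn.NLLLoss()(log_prob, target)            # targets 0 and 1 are now in range

# NLLLoss picks -log_prob[i, target[i]] for each row and averages over rows
manual = -log_prob[torch.arange(M), target].mean()
print(loss, manual)
```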
Case 2: Weights for the negative and positive classes.
This time, I kept the input tensors the same and passed weights for both classes. What I understood from the following error is that nn.NLLLoss() requires a weight for each pixel index (i.e., Case 3).
>> weight = torch.tensor([1.0, 44.0]) # (-ve, +ve) class weights
>> nn.NLLLoss(weight=weight)(log_prob, target)
"RuntimeError: weight tensor should be defined either for all 1 classes or no classes but got weight tensor of shape: [2]"
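The error message suggests the weight is per class, not per pixel: it must have C entries, and with the (M, 1) input C is 1, so a 2-entry weight is rejected. A minimal sketch, assuming a two-channel (M, 2) input, showing that [1.0, 44.0] is then accepted and what reduction='mean' computes with class weights (a weighted average normalized by the summed weights of each element's target class):

```python
import torch
from torch import nn

torch.manual_seed(0)
M = 16
log_prob = torch.log_softmax(torch.randn(M, 2), dim=1)  # (M, C) with C == 2 (assumed)
target = (torch.rand(M) > 0.5).long()                   # (M,)
weight = torch.tensor([1.0, 44.0])                      # one weight per class: (neg, pos)

loss = nn.NLLLoss(weight=weight)(log_prob, target)

# Manual equivalent of reduction='mean' with class weights
w = weight[target]                                      # (M,) class weight looked up per element
manual = -(w * log_prob[torch.arange(M), target]).sum() / w.sum()
print(loss, manual)
```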
Case 3: Weight per pixel.
This time, the loss function again had a problem with the weights.
>> weight = target * 44.0
>> nn.NLLLoss(weight=weight)(log_prob, target)
"RuntimeError: weight tensor should be defined either for all 1 classes or no classes but got weight tensor of shape: [131072]"
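Since weight is per class, per-pixel weighting has to be done by hand. One common pattern, sketched below under the assumption of a two-channel (M, 2) input, is reduction='none' followed by an explicit weighted average. The pixel_weight map is hypothetical; note that target * 44.0 would give negative pixels a weight of 0, i.e. ignore them entirely.

```python
import torch
from torch import nn

torch.manual_seed(0)
M = 16
log_prob = torch.log_softmax(torch.randn(M, 2), dim=1)  # (M, 2), assumed two-channel output
target = (torch.rand(M) > 0.5).long()                   # (M,)

# Hypothetical per-pixel weights: 44 for positive pixels, 1 for negatives
pixel_weight = torch.ones(M)
pixel_weight[target == 1] = 44.0

per_pixel = nn.NLLLoss(reduction="none")(log_prob, target)  # (M,) unreduced losses
loss = (pixel_weight * per_pixel).sum() / pixel_weight.sum()
print(loss)
```

Because this particular map depends only on the class, it reproduces nn.NLLLoss(weight=torch.tensor([1.0, 44.0])); an arbitrary map (e.g. boundary emphasis) is where the manual route is actually needed.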
I would highly appreciate it if you could explain how this weight parameter is designed to work. A reference to the weight or pos_weight parameter in nn.BCEWithLogitsLoss() would be a huge plus.
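For comparison, a minimal sketch of the two weighting knobs in nn.BCEWithLogitsLoss, which works directly on single-channel logits with float targets of the same shape. pos_weight (one entry per class, here a single channel) multiplies only the positive term; weight is broadcast and rescales every element's loss, and with reduction='mean' the result is a plain mean rather than one normalized by the weights. The value 44.0 and the pixel_weight map are illustrative.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)                  # single-channel logits, as in the post
target = (torch.rand(2, 1, 4, 4) > 0.5).float()   # BCE expects float targets, same shape

# pos_weight scales only the y * log(sigmoid(x)) term
bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([44.0]))
loss = bce(logits, target)
# Manual form: -[pw * y * log s(x) + (1 - y) * log(1 - s(x))], averaged over all elements
manual = -(44.0 * target * F.logsigmoid(logits) + (1 - target) * F.logsigmoid(-logits)).mean()

# weight rescales each element's loss term (it can be a per-pixel map)
pixel_weight = 1.0 + 43.0 * target                # hypothetical map: 44 on positives, 1 elsewhere
loss_w = nn.BCEWithLogitsLoss(weight=pixel_weight)(logits, target)
print(loss, manual, loss_w)
```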
-----------------------------------------------------------------------
P.S. To randomly generate the binary target tensor, I first used target = torch.rand((8, 1, 128, 128), dtype=torch.bool), which caused the error RuntimeError: "check_uniform_bounds" not implemented for 'Bool'.
Isn’t the random generation of a bool type tensor supported?
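torch.rand samples from a uniform distribution on [0, 1), which is only defined for floating-point dtypes, so a bool dtype is rejected. A few workarounds that produce random binary masks (remember to call .long() before passing targets to NLLLoss):

```python
import torch

shape = (8, 1, 128, 128)
# Draw floats and threshold them: the comparison yields dtype=torch.bool
target_bool = torch.rand(shape) > 0.5
# Integer draws in {0, 1} (the upper bound 2 is exclusive): dtype=torch.int64
target_int = torch.randint(0, 2, shape)
# Bernoulli samples with p = 0.5, cast to bool
target_bern = torch.bernoulli(torch.full(shape, 0.5)).bool()
print(target_bool.dtype, target_int.dtype, target_bern.dtype)
```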