Hi everyone,
I am working on a dataset of about 2000 retinal images, each of which has one or more labels (multi-label classification). There are 28 labels in total, and the targets are stored in a one-hot coded structure (strictly speaking multi-hot, since several entries can be 1 at once).
For instance, the target [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] means that the retinal image has the labels class 0 (diabetic retinopathy) and class 16 (macular degeneration).
Importantly, the classes are imbalanced. To my knowledge, there are two ways to deal with this problem: 1) Weighting the loss by the class weights. As discussed in Multi-Label, Multi-Class class imbalance - #2 by ptrblck, I used BCEWithLogitsLoss, setting the “reduction” parameter to “none” and multiplying the loss by the class weights.
2) Oversampling using WeightedRandomSampler
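For reference, this is roughly what I mean by the first approach. It is only a minimal sketch: the logits, targets, and class_weights below are random placeholders I made up for illustration, not my real model outputs or weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_classes = 28
batch_size = 4

# Placeholder model outputs and multi-hot targets (assumptions, not real data).
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, 2, (batch_size, num_classes)).float()

# Placeholder per-class weights; in practice these could be inverse class frequencies.
class_weights = torch.rand(num_classes) + 0.5

# reduction="none" keeps one loss value per (sample, class) pair.
criterion = nn.BCEWithLogitsLoss(reduction="none")
per_element_loss = criterion(logits, targets)      # shape: (batch_size, num_classes)

# class_weights has shape (num_classes,), so it broadcasts over the batch dimension.
weighted_loss = per_element_loss * class_weights
loss = weighted_loss.mean()                        # scalar used for backward()
```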
This is where my problem lies. As the targets are in a one-hot coded format, how can we calculate the sample_weights tensor? I have read the whole post Class imbalance with image segmentation - #3 by An18, but I am still not sure whether the right solution was provided there (in reply to An18, it was said that the sample weights could be calculated as “samples_weights = y*weights.T”). Also, if I calculate the sample weights in this way and create my_sampler as:
my_sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
and then pass it to the sampler parameter of the DataLoader, an error is raised:
“invalid multinomial distribution (sum of probabilities <= 0)”
Would calculating the sample_weights in the following way be the right solution?
sample_weights = torch.mean(torch.mul(targets, class_weights.T), 1)
I am new to PyTorch and would deeply appreciate your help.
Here is the relevant part of my code:
import torch
from torch.utils.data import DataLoader

# Number of images carrying each of the 28 labels
class_counts = torch.tensor([376, 100, 317, 138, 101, 73, 186, 14, 47, 15, 37, 282, 28, 6, 16, 65, 58, 5, 17, 11, 14, 43, 32, 15, 22, 11, 6, 34])
# Inverse-frequency class weights, shape (28,)
class_weights = 1.0 / class_counts.float()
# One weight per sample: mean over the class dimension of the multi-hot targets
sample_weights = torch.mean(torch.mul(targets, class_weights.T), 1)
my_sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
train_loader = DataLoader(train_dataset, batch_size=128, sampler=my_sampler)
next(iter(train_loader))
The last line is where the error is raised if I do not take the mean over dimension 1 of “torch.mul(targets, class_weights.T)”.
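To make the question easier to reproduce, here is a self-contained sketch of the mean-based variant. The targets and images are random placeholders (assumptions for illustration only, with 200 samples instead of my real data), and I force every sample to have at least one label so that no sample weight is zero. As far as I understand, the sampler expects a 1-D sequence with one weight per sample, which is what the mean over the class dimension produces.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

class_counts = torch.tensor([376, 100, 317, 138, 101, 73, 186, 14, 47, 15,
                             37, 282, 28, 6, 16, 65, 58, 5, 17, 11,
                             14, 43, 32, 15, 22, 11, 6, 34])
num_classes = class_counts.numel()
num_samples = 200  # placeholder size, not the real dataset size

# Placeholder multi-hot targets; each row is guaranteed at least one label.
targets = (torch.rand(num_samples, num_classes) < 0.1).float()
targets[torch.arange(num_samples), torch.randint(0, num_classes, (num_samples,))] = 1.0

class_weights = 1.0 / class_counts.float()        # shape: (num_classes,)

# (num_samples, num_classes) * (num_classes,) broadcasts row-wise.
per_label_weights = targets * class_weights
# Collapse the class dimension so there is exactly one weight per sample.
sample_weights = per_label_weights.mean(dim=1)    # shape: (num_samples,)

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)

images = torch.zeros(num_samples, 3, 8, 8)        # placeholder images
dataset = TensorDataset(images, targets)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

batch_images, batch_targets = next(iter(loader))
```

Since the mean divides every row by the same constant (the number of classes), it should give the same relative sampling probabilities as summing the weights of the labels present in each sample.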