I am dealing with extremely imbalanced data (99% negative, 1% positive). I used a WeightedRandomSampler (WRS) to deal with this. Below is the WRS code I used.
neg = 0.0
pos = 0.0
targets = []
# Count positives/negatives and record a 0/1 label per sample
for _, target in traindataset:
    targets.append(target.max())
    if target.max() == 1:
        pos += 1
    else:
        neg += 1

# pos_weight for BCEWithLogitsLoss: ratio of negatives to positives
pos_weight = torch.tensor(neg / pos)

targets = torch.stack(targets).long()

# Per-class weight = inverse class frequency
class_sample_count = [neg, pos]
weights = 1 / torch.tensor(class_sample_count, dtype=torch.float)
samples_weights = [weights[t] for t in targets]

sampler = torch.utils.data.sampler.WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True,
)
data_loader = torch.utils.data.DataLoader(
    traindataset,
    batch_size=batch_size,
    shuffle=False,  # must be False when a sampler is passed
    sampler=sampler,
    num_workers=0,
    pin_memory=False,
)
When using WRS, the training accuracy exceeded 90% within 4 epochs, but the validation accuracy never exceeded 10%.
To find the cause, I inspected the training images in each batch with WRS applied, and confirmed that because the data is so extremely imbalanced, the positive samples were being duplicated heavily.
Because of this overfitting, I wonder if there is a sampler that undersamples the negative data while leaving the positive data intact.
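To illustrate what I have in mind: something like the sketch below, which keeps every positive sample and randomly drops negatives until the classes are balanced, then feeds the surviving indices to a SubsetRandomSampler. This is only a sketch under my assumptions (binary 0/1 targets, the `targets` tensor built as above); the toy `targets` here stands in for the real one.

```python
import torch
from torch.utils.data import SubsetRandomSampler

# Toy stand-in for the real 0/1 target tensor built in the loop above
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])

pos_idx = (targets == 1).nonzero(as_tuple=True)[0]
neg_idx = (targets == 0).nonzero(as_tuple=True)[0]

# Undersample: randomly keep only as many negatives as there are
# positives, without replacement, so no positive is duplicated or lost
perm = torch.randperm(len(neg_idx))[:len(pos_idx)]
kept_neg_idx = neg_idx[perm]

indices = torch.cat([pos_idx, kept_neg_idx]).tolist()
sampler = SubsetRandomSampler(indices)
# data_loader = torch.utils.data.DataLoader(
#     traindataset, batch_size=batch_size, sampler=sampler)
```

Each epoch would then see all positives plus a random subset of negatives, at the cost of discarding most negative samples; re-drawing `perm` every epoch would let different negatives rotate in.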