Hello,
I’m currently working on an NLP multi-class classification problem with an unbalanced dataset.
After some research, I found that by using WeightedRandomSampler I could avoid the problem of the largest class being trained on and predicted over and over again, with the other classes showing up only occasionally.
(This faulty training leads to the validation loss going up while the training loss and an initially high validation accuracy go down.)
My question is the following: is it correct/fair to apply the WeightedRandomSampler to the test and validation datasets as well? Or shouldn't it make any difference at all?
Here is how I build my datasets, samplers and dataloaders:
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Create datasets from tensors (converted from numpy beforehand)
train_dataset = TensorDataset(tensor_x_train, tensor_y_train)
valid_dataset = TensorDataset(tensor_x_valid, tensor_y_valid)
test_dataset = TensorDataset(tensor_x_test, tensor_y_test)

# Compute one weight per CLASS (inverse class frequency).
# class_sample_count holds the number of samples in each class.
weights = 1. / torch.tensor(class_sample_count, dtype=torch.float)
weights = weights.double()

# Look up the class weight for each SAMPLE via its label
sample_weight_train = weights[tensor_y_train]
sample_weight_test = weights[tensor_y_test]
sample_weight_val = weights[tensor_y_valid]

# Create samplers (replacement defaults to True when not given)
sampler_train = WeightedRandomSampler(
    weights=sample_weight_train,
    num_samples=len(sample_weight_train),
    replacement=True)
sampler_test = WeightedRandomSampler(
    weights=sample_weight_test,
    num_samples=len(sample_weight_test))
sampler_val = WeightedRandomSampler(
    weights=sample_weight_val,
    num_samples=len(sample_weight_val))

# Create dataloaders
train_dataloader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=sampler_train)
valid_dataloader = DataLoader(valid_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=sampler_val)
test_dataloader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE,
                             sampler=sampler_test)
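As a sanity check on what the sampler does to the class distribution, here is a minimal sketch with made-up labels. The 900/90/10 split and the use of `torch.bincount` to obtain `class_sample_count` are assumptions for illustration, not my real data. It draws one "epoch" of indices and counts how often each class appears.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # make the draw reproducible

# Made-up, imbalanced labels: 900 / 90 / 10 samples for classes 0 / 1 / 2
labels = torch.cat([
    torch.zeros(900, dtype=torch.long),
    torch.ones(90, dtype=torch.long),
    torch.full((10,), 2, dtype=torch.long),
])

# One weight per class (inverse frequency), then one weight per sample
class_sample_count = torch.bincount(labels)
class_weights = 1.0 / class_sample_count.double()
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# The sampler yields dataset indices; map them back to labels
sampled_indices = torch.tensor(list(sampler))
sampled_counts = torch.bincount(labels[sampled_indices], minlength=3)

print("original counts:", class_sample_count.tolist())
print("sampled counts: ", sampled_counts.tolist())
```

With inverse-frequency weights each class is drawn with roughly equal probability, so the sampled counts should come out close to one third each, and the rare classes are necessarily repeated many times.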
And this is how I measure the accuracy of my model on the test set:
model.eval()  # switch off dropout / batch-norm updates for evaluation
epoch_loss = 0.0
epoch_acc = 0.0
with torch.no_grad():
    for x_test, y_test in test_dataloader:
        y_pred = model(x_test)
        loss = criterion(y_pred, y_test)
        acc = binary_accuracy(y_pred, y_test)  # returns a number from 0.0 to 1.0
        epoch_loss += loss.item()
        epoch_acc += acc.item()
# Average over the number of test batches
print("LOSS: ", epoch_loss / len(test_dataloader))
print("ACC: ", epoch_acc / len(test_dataloader))
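One detail about the loop above: dividing by the number of batches averages per-batch means, so a smaller final batch counts as much as a full one. A per-sample alternative could look like the sketch below. It assumes the model outputs logits of shape (batch, num_classes) and that the criterion returns a batch mean (the `nn.CrossEntropyLoss` default); `evaluate`, `toy_model` and the random data are hypothetical names for illustration, not part of my code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, dataloader, criterion):
    """Return (mean loss per sample, accuracy) over every sample in dataloader."""
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            logits = model(x)  # assumed shape: (batch, num_classes)
            batch_size = y.size(0)
            # criterion is assumed to return the batch mean, so rescale by batch size
            total_loss += criterion(logits, y).item() * batch_size
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total_seen += batch_size
    return total_loss / total_seen, total_correct / total_seen


# Toy usage with random data and an untrained linear model (all made up)
torch.manual_seed(0)
x = torch.randn(10, 4)
y = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(x, y), batch_size=4)  # batches of 4, 4 and 2
toy_model = nn.Linear(4, 3)
loss, acc = evaluate(toy_model, loader, nn.CrossEntropyLoss())
print(f"loss={loss:.4f} acc={acc:.2f}")
```

Because the totals are accumulated per sample, the result matches what a single pass over the whole set would give, regardless of the batch size.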