I am trying to balance my data for a multi-classes classification task to get better scores, using weights and torch.utils.data.sampler.WeightedRandomSampler() I get an error that I don’t understand. Is there any other way to handle imbalanced classes easily ? Here is a snippet of my code and the error in hand:
.
train_set = SentimentDataset(file=TRAIN_DATA, word2idx=word2idx, tword2idx=tword2idx,
max_length=0, max_topic_length=0, topic_bs=True)
val_set = SentimentDataset(file=VAL_DATA, word2idx=word2idx, tword2idx=tword2idx,
max_length=0, max_topic_length=0, topic_bs=True)
_weights = torch.FloatTensor(train_set.weights) # train_set.weights : [296, 3381, 12882, 12857, 1016]
_weights = _weights.view(1, 5)
_weights = _weights.double()
sampler = torch.utils.data.sampler.WeightedRandomSampler(_weights, BATCH_SIZE)
loader_train = DataLoader(train_set, batch_size=BATCH_SIZE,
shuffle=False, sampler=sampler, num_workers=4)
loader_val = DataLoader(val_set, batch_size=BATCH_SIZE,
shuffle=False, sampler=sampler, num_workers=4)
model = RNN(embeddings, num_classes=num_classes, **_hparams)
model.cuda()
criterion = torch.nn.CrossEntropyLoss()
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters)
# TRAIN
...
class SentimentDataset(Dataset):
def __init__(self, file, max_length, max_topic_length, word2idx, tword2idx, topic_bs):
...
self.data = [SocialTokenizer(lowercase=True).tokenize(x)for x in self.data]
self.topics = [SocialTokenizer(lowercase=True).tokenize(x) for x in self.topics]
self.label_encoder = preprocessing.LabelEncoder()
self.label_encoder = self.label_encoder.fit(self.labels)
self.label_count = Counter(self.labels)
self.weights = [self.label_count['-2'], self.label_count['-1'],
self.label_count['0'], self.label_count['1'],
self.label_count['2']]
...
def __getitem__(self, index):
sample, label, topic = self.data[index], self.labels[index], self.topics[index]
File "/home/kostas/anaconda3/envs/pytorch_env/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 40, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/kostas/Gitlab/SemTest/models/datasets.py", line 108, in __getitem__
sample, label, topic = self.data[index], self.labels[index], self.topics[index]
TypeError: list indices must be integers or slices, not torch.LongTensor