I’m trying to implement stratified sampling on my dataset, computing a weight for each sample with the following function:
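```python
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_weights_for_balanced_classes(labels):
    unique_labels, counts = np.unique(labels, return_counts=True)
    # inverse-frequency weight for each class
    weight_per_class = np.sum(counts) / counts
    weights = [0] * len(labels)
    for i, val in enumerate(labels):
        # look up the weight of this sample's class (np.where(...)[0] is an index array)
        weights[i] = weight_per_class[np.where(unique_labels == val)[0]]
    return weights

weights = make_weights_for_balanced_classes(labels)
sampler = WeightedRandomSampler(weights, len(weights))
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
```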
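For comparison, the same inverse-frequency weights can be computed without the Python loop by using `np.unique(..., return_inverse=True)`. This is a minimal sketch with made-up toy labels, and `balanced_class_weights` is a hypothetical name. It returns one float per sample in a flat 1-D array, whereas the lookup inside the loop, `weight_per_class[np.where(unique_labels == val)[0]]`, indexes with an index array and therefore returns a one-element array rather than a float:

```python
import numpy as np

def balanced_class_weights(labels):
    """Inverse-frequency weight for every sample, as a flat 1-D float array."""
    # inverse[i] is the position of labels[i] within the sorted unique labels
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    weight_per_class = counts.sum() / counts
    # integer-array indexing gives one float per sample
    return weight_per_class[inverse]

labels = [0, 0, 0, 1]  # toy labels for illustration only
weights = balanced_class_weights(labels)
print(weights.shape)  # (4,)

# The lookup used in the loop version returns a 1-element array, not a float:
unique_labels, counts = np.unique(labels, return_counts=True)
per_class = counts.sum() / counts
print(per_class[np.where(unique_labels == 1)[0]].shape)  # (1,)
```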
But when I enumerate the dataloader, this error is raised in my custom dataset’s `__getitem__`:

```
list indices must be integers or slices, not list
```
I wanted to know whether it’s normal to get a list of indices instead of a single index when using `WeightedRandomSampler`. Should I change my dataset to accept a list of indices whenever I use `WeightedRandomSampler`?
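As far as I understand it, `WeightedRandomSampler` is meant to yield one integer index per draw, so the dataset should not need to accept lists; receiving a list is more likely a sign that `weights` is not 1-D. Below is a minimal sketch, assuming a hypothetical `ToyDataset` and made-up labels, that records which indices actually reach `__getitem__` when the weights are a flat float array:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

class ToyDataset(Dataset):
    """Hypothetical dataset that records every index passed to __getitem__."""

    def __init__(self, labels):
        self.labels = list(labels)
        self.seen = []

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        self.seen.append(idx)
        return idx, self.labels[idx]  # plain list indexing requires an int

labels = [0, 0, 0, 1]  # toy labels for illustration only
_, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
weights = (counts.sum() / counts)[inverse]  # 1-D: one float per sample

sampler = WeightedRandomSampler(weights, num_samples=len(weights))
dataset = ToyDataset(labels)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for batch_idx, batch_labels in loader:
    print(batch_idx, batch_labels)

print(all(isinstance(i, int) for i in dataset.seen))  # True
```

If `weights` is instead a list of one-element arrays, it becomes an `(N, 1)` tensor, and `torch.multinomial` draws separately for each row, which would explain each index arriving as a list. Recent PyTorch releases appear to reject non-1-D weights with a `ValueError` in the sampler’s constructor instead.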