WeightedRandomSampler with multi-dimensional batch

I’m working on a classification problem (100 classes) and my dataset has a huge class imbalance. To tackle this, I’m considering using torch’s WeightedRandomSampler to oversample the minority classes. I took help from this post, which seemed pretty straightforward. My only concern is the nature of my dataset.

In my case, each sample (1 point in a batch) contains 8 points. Each of these 8 points has one true class out of 100 classes, so my output shape is (bs x 8). Hence, the final weight variable has length total_dataset_length*8.

Here’s my implementation:

import numpy as np
from sklearn.utils import class_weight
from torch.utils.data import WeightedRandomSampler

y_org = np.load('target.npy')                 # 5000 x 8
samples_per_class = np.unique(y_org.ravel(), return_counts=True)[1]

class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=np.unique(y_org.ravel()),
                                                  y=y_org.ravel())
weights = class_weights[y_org.ravel()]
sampler = WeightedRandomSampler(weights, len(y_org.ravel()), replacement=True)

To count the number of occurrences of each class index, I have to unroll (ravel) the ground-truth array along the first dimension. Since the final weight variable has length total_dataset_length*8, it causes indexing errors during loading:

IndexError: list index out of range
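The mismatch is reproducible in miniature (toy dataset and names, just for illustration): with N = 5 samples but N*8 = 40 weights, the sampler draws indices beyond the dataset length.

```python
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

class TinyDataset(Dataset):
    """Stand-in for the real dataset: N = 5 samples."""
    def __init__(self):
        self.data = list(range(5))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]          # raises IndexError for index >= 5

torch.manual_seed(0)
# 5 * 8 = 40 weights -> the sampler may emit indices anywhere in [0, 39]
sampler = WeightedRandomSampler(torch.ones(40), num_samples=40, replacement=True)
loader = DataLoader(TinyDataset(), sampler=sampler)

caught = None
try:
    for batch in loader:
        pass
except IndexError as err:
    caught = str(err)

print(caught)   # list index out of range
```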

How can I use WeightedRandomSampler in such cases?

Could you explain why the samples are created as [batch_size, 8, ...] tensors?
Assuming each “point” is independent, you should be able to flatten the weights and sample each point in the __getitem__ separately. Then in the training loop you might want to create a view in your desired shape.
However, if each call into __getitem__ creates a tensor of [1, 8], then you would have to think about an approach to create a proper weight value for all 8 points.
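A minimal sketch of that flattening approach, with toy shapes, names, and inverse-frequency weights assumed for illustration: each of the N*8 points becomes its own sample, so the weight vector lines up with the dataset indices, and batches can later be viewed in any desired shape.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

class FlatPointDataset(Dataset):
    """Treats every one of the N*8 points as an individual sample."""
    def __init__(self, features, targets):
        # features: (N, 8, D) -> (N*8, D); targets: (N, 8) -> (N*8,)
        self.features = features.reshape(-1, features.shape[-1])
        self.targets = targets.reshape(-1)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        return torch.from_numpy(self.features[index]), int(self.targets[index])

# toy data: 10 samples of 8 points with 4 features each, 3 imbalanced classes
feats = np.random.randn(10, 8, 4).astype(np.float32)
targs = np.random.choice(3, size=(10, 8), p=[0.8, 0.15, 0.05])

dataset = FlatPointDataset(feats, targs)
counts = np.bincount(dataset.targets, minlength=3)
class_w = 1.0 / np.maximum(counts, 1)         # inverse-frequency weights
weights = class_w[dataset.targets]            # one weight per flattened point

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

x, y = next(iter(loader))
print(x.shape, y.shape)   # torch.Size([16, 4]) torch.Size([16])
```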

@ptrblck The problem that I’m working on concerns assigning a class to a bunch of (8) vectors in 8 different bins. These 8 vectors are independent of each other and are procured directly from the __getitem__ function itself. Here, each “point” comprises 8 vectors, and the dimensions of 1 point are 1x8.

Does it mean the actual target is created in the __getitem__ and not known before?
If so, I doubt you could easily use a weighted sampler, since you would need to know the targets to be able to add a weight to each index.

Yes, the target is created inside the __getitem__ function. It has previously been stored in a .npy file.

class CustomDataset(Dataset):
    def __init__(self, A_path, B_path, targetA_path, targetB_path):
        self.matA = np.load(A_path)             # (65098, 8, 4)
        self.matB = np.load(B_path)             # (65098, 500, 4)
        self.targetA = np.load(targetA_path)    # (65098, 8, 500)
        self.targetB = np.load(targetB_path)    # (65098, 8)

    def __len__(self):
        return self.matA.shape[0]

    def __getitem__(self, index):
        return torch.from_numpy(self.matA[index]), \
               torch.from_numpy(self.matB[index]), \
               torch.from_numpy(self.targetA[index]), \
               torch.from_numpy(self.targetB[index])
This is how the dataset class is defined. Here, 65098 denotes the number of datapoints, 500 the number of classes, and 4 the dimensionality of each input vector.

I’m trying to calculate the weight matrix across targetA and then sample accordingly.

Sorry, I misunderstood your use case as it seems the targets are not generated in the __getitem__, but just loaded.
Also it seems that the self.matX are just indexed and contain multiple values.
You are generally supposed to add a weight to each sample which is indexed via the passed index to __getitem__. I’m thus unsure why you would need to flatten one of the input tensors since you are directly indexing them via [index].

The class indices in targetA are in this form:

and so on ...

So flattening it gives me the number of occurrences of each class, and weighting then becomes possible.
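For example, with a toy (N, 8) target array (made-up values, just for illustration), raveling yields the per-class counts needed for balanced-style weighting:

```python
import numpy as np

# toy stand-in for the class indices: (N, 8) array of class ids
target = np.array([[0, 0, 1, 2, 0, 1, 0, 0],
                   [2, 0, 0, 0, 1, 0, 0, 0]])

classes, counts = np.unique(target.ravel(), return_counts=True)
print(classes)  # [0 1 2]
print(counts)   # [11  3  2]

# balanced-style weight per class, then one weight per (sample, point)
class_w = counts.sum() / (len(classes) * counts)
point_w = class_w[target]          # shape (N, 8), same layout as target
```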

Since each sample contains multiple classes, I don’t think you can easily add a weight to it, since again each index will be drawn using the weight. You might need to apply a more advanced technique as described here or just return samples containing a single target.
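For example, one heuristic in that direction (my own sketch, not necessarily the technique from the link) would be to give each dataset index a single weight by averaging the inverse-frequency weights of its 8 class labels, so samples rich in rare classes are drawn more often:

```python
import numpy as np
from torch.utils.data import WeightedRandomSampler

# toy stand-in for an (N, 8) target array of class indices, heavily imbalanced
rng = np.random.default_rng(0)
target = rng.choice(5, size=(100, 8), p=[0.6, 0.2, 0.1, 0.05, 0.05])

counts = np.bincount(target.ravel(), minlength=5)
class_w = 1.0 / np.maximum(counts, 1)         # inverse-frequency per class

# one weight per dataset index: mean of its 8 per-point class weights
sample_w = class_w[target].mean(axis=1)       # shape (100,)

sampler = WeightedRandomSampler(sample_w, num_samples=len(sample_w),
                                replacement=True)
```

This keeps one weight per `index` passed to `__getitem__`, so the sampler and dataset lengths match, at the cost of only approximating per-class balance.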


Sure. I’ll look into it. Thanks, @ptrblck !