Some problems with WeightedRandomSampler

Oh sure. The warning was introduced later, after that PR was merged, but we should change this. Could you open an issue?

Sure!
I’ll open an issue and suggest a fix.

Hello everyone,

I’m replying here since my problem is similar to the first question asked, and I’d like to avoid opening another thread.

I’m struggling with a 3-class problem with a large class imbalance. In the following code I create the weight arrays (first the one with shape [num_classes], then the one with shape [data_length]):

weight = 1. / count_train  # per-class weights, shape [num_classes]
self.train_samples_weight = [weight[cla] for cla in Labels_train]
self.test_samples_weight = [weight[class_id] for class_id in Labels_test]
self.train_samples_weight = np.asarray(self.train_samples_weight)  # shape [train_length]
self.test_samples_weight = np.asarray(self.test_samples_weight)    # shape [test_length]

and it looks correct to me. Here is the code where I create the Sampler and the Loader:

train_sampler = torch.utils.data.WeightedRandomSampler(self.train_samples_weight, 1, replacement=True)
test_sampler = torch.utils.data.WeightedRandomSampler(self.test_samples_weight, 1, replacement=True)
test_loader = DataLoader(dataset=self.dataset_test, batch_size=self.batch_size, sampler=test_sampler, drop_last=True)
train_loader = DataLoader(dataset=self.dataset_train, batch_size=self.batch_size, sampler=train_sampler, drop_last=True)

The first thing is that I’m using 1 as num_samples, since I just want to extract a number of samples corresponding to my batch size. Anyway, this is misleading, since the other samplers don’t work like this, as pointed out here, but I guess the problem is in my PyTorch version.

The real issue is the class distribution when I use enumerate on my DataLoader. In fact, I only get samples of class_2 (the class with the largest number of occurrences), even though its weight seems to be the lowest, which looks correct to me. So what am I doing wrong? I was looking for a way to produce balanced batches, and I still think this is the correct way to go.

num_samples gives the number of samples to draw, so usually you would set it to the length of your Dataset. Currently you should only get a single batch containing one sample.
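
Here is a minimal sketch (dummy tensors, illustrative names) of how num_samples controls how many indices are drawn and therefore how many batches you get:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

data = torch.randn(8, 3)
target = torch.zeros(8, dtype=torch.long)
dataset = TensorDataset(data, target)
weights = torch.ones(8)  # uniform weights, just to show the mechanics

# num_samples=1 -> one index is drawn -> a single batch containing one sample
loader_one = DataLoader(dataset, batch_size=4,
                        sampler=WeightedRandomSampler(weights, num_samples=1, replacement=True))

# num_samples=len(dataset) -> a full epoch of batches
loader_full = DataLoader(dataset, batch_size=4,
                         sampler=WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True))

print(len(list(loader_one)))   # 1
print(len(list(loader_full)))  # 2 (8 samples with batch_size 4)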

Yes, I just checked this and that’s true. Anyway, I can’t understand why num_samples should be equal to my dataset size. In fact, I have already seen examples with this setting, and it made no sense to me.

Without specifying the Sampler, my enumerate(loader) yields batchID, data, where data has dimensions [batchSize, …my data dimensions…]. However, when I use a customized Sampler, I end up with data dimensions of [batchSize, numSamples, …my data dimensions…].

Can you explain why it should work like that? I can’t see the point in this. I was expecting to still have the same dimensions, but with samples chosen according to the new criteria.

That shouldn’t be the case and I’ve never seen this behavior before.
Could you post a code snippet, so that we can reproduce this issue?
Here is a dummy example which results in [batch_size, nb_features] for each batch:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Create dummy data with class imbalance 99 to 1
numDataPoints = 1000
data_dim = 5
bs = 100
data = torch.randn(numDataPoints, data_dim)
target = torch.cat((torch.zeros(int(numDataPoints * 0.99), dtype=torch.long),
                    torch.ones(int(numDataPoints * 0.01), dtype=torch.long)))

print('target train 0/1: {}/{}'.format(
    (target == 0).sum(), (target == 1).sum()))

# Compute samples weight (each sample should get its own weight)
class_sample_count = torch.tensor(
    [(target == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in target])

# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

# Iterate DataLoader and check class balance for each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i, (y == 0).sum(), (y == 1).sum()))
    print("x.shape {}, y.shape {}".format(x.shape, y.shape))

Hi Peter,

Is it possible for samples_weight to have size 128x128, since I’m dealing with dense image prediction?
Every target image has size 128x128. I have 3300 image pairs in total (3300 original images, 3300 segmentation masks). In this case, to deal with the class imbalance, I concatenated the target masks to find the counts and class labels:

mask_total = torch.tensor([], device=device)  # accumulate all masks
for i in range(len(total_data)):
    sample = total_data[i]
    mask = sample['parc1a'].float()
    mask = mask.to(device)
    mask_total = torch.cat((mask_total, mask))

# count the pixels per class over all masks
unique_color, count = np.unique(mask_total.cpu(), return_counts=True)

I have 171 classes in total.

I’m confused about how to construct the sampler in my case.
The line samples_weight = torch.tensor([weight[t] for t in target]) doesn’t work for me, since my target size is [422400, 128], so it can’t be used as an index.

Your dummy code works fine in my environment. Anyway, when I try it with my data, I still have the same problem. Below you can find my Dataset implementation; maybe this is the source of the problem (honestly, I don’t think so, since it is a really simple class):

from torch.utils.data import Dataset

class triaxial_dataset(Dataset):
    """Time Series dataset."""

    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx][0]
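
Indexing it with dummy tensors shaped like my data returns what I’d expect (a quick check, illustrative values only):

import torch

data = torch.randn(10, 600, 6)
label = torch.zeros(10, 1, dtype=torch.long)
ds = triaxial_dataset(data, label)
print(len(ds))         # 10
print(ds[0][0].shape)  # torch.Size([600, 6])
print(ds[0][1])        # tensor(0)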

The only difference between my script and yours is the number of workers. I’m not using parallel processing, so I’m running with the default 0. Could this change the sampler’s behaviour? Anyway, I also tried it with num_workers=1, but my dataset is 30000x600x6, so it takes an indefinite amount of time to load (is this normal?).

I’m not sure if I understand it correctly, but it seems you would like to sample single pixels from your masks?
That doesn’t seem to make much sense, as your input image and mask would then be created from pixel information of various other images and masks.

Your use case is basically similar to a multi-label classification, i.e. each sample might have more than a single active target. This case might be a bit complicated, and I’m really not sure how to set the weights properly.
If the majority class often occurs together with a minority class, you can’t really oversample the minority class. Maybe @rasbt has some ideas on this particular case.
In the past I tried to get the class counts for all classes individually (although they often appeared in pairs) and created weights by multiplying the frequencies; see the sketch below. Still, I’m not sure if that is the best approach.
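
Roughly, that frequency-multiplication idea could look like this (dummy multi-hot targets; just a heuristic sketch, not an established recipe):

import torch

# each row is one sample, each column one class (1 = class present)
targets = torch.tensor([[1, 0, 1],
                        [1, 1, 0],
                        [0, 0, 1],
                        [1, 0, 0]], dtype=torch.float)

class_counts = targets.sum(dim=0)  # occurrences per class
class_weights = 1. / class_counts  # inverse frequency per class

# weight each sample by the product of the weights of its active classes
sample_weights = torch.where(
    targets.bool(), class_weights, torch.ones_like(class_weights)).prod(dim=1)
print(sample_weights)  # tensor([0.1667, 0.3333, 0.5000, 0.3333])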

Your Dataset just uses the index to get the current data and target, so this shouldn’t be a problem.
Also, the number of workers does not change the behavior, as you can see if you set num_workers=0 in my example.

If your data has the shape [30000, 600, 6], each batch will have the shape [batch_size, 600, 6], regardless of the sampler you are using.
In the other post you described it as [batch_size, num_samples, data_dimension], but I think it’s just [batch_size, data_dim0, data_dim1].
Could you check it and let me know if you still have doubts?
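
Here is a quick sanity check with dummy data of that shape (illustrative names); the sampler does not add a dimension:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

data = torch.randn(100, 600, 6)  # [num_samples, dim0, dim1]
target = torch.zeros(100, dtype=torch.long)
dataset = TensorDataset(data, target)
sampler = WeightedRandomSampler(torch.ones(100), num_samples=len(dataset))

x, y = next(iter(DataLoader(dataset, batch_size=10, sampler=sampler)))
print(x.shape)  # torch.Size([10, 600, 6]) -> [batch_size, dim0, dim1]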

I checked this again and when I run:


for batch_idx, (data, target) in enumerate(train_loader):

my target.size() returns


torch.Size([170, 1530, 600, 8])

which corresponds to [batch_size, num_samples, dim0, dim1], while with no Sampler selected I get [batch_size, dim0, dim1], as expected. This confuses me, since your dummy script works fine.

I also have another question: when I use a custom Sampler, my execution time becomes really high compared to the default sampler. Is this expected behavior?

Did you write the custom sampler yourself, or are you using the WeightedRandomSampler?
In the former case, would it be possible to have a look at your code, as there might be a bug somewhere?
The execution time depends on your sampling strategy. You could create all “static” members in the __init__ method and just sample in the __iter__ method; see the sketch below. Generally, random samplers might be a bit slower, as they call functions like torch.multinomial, but this should count as noise in comparison to the overall execution time of your script.
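
A minimal skeleton of that pattern (illustrative names; just a sketch, not a replacement for WeightedRandomSampler):

import torch
from torch.utils.data import Sampler

class MyWeightedSampler(Sampler):
    def __init__(self, weights, num_samples):
        # create all "static" members once here
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples

    def __iter__(self):
        # only the cheap sampling happens when the loader iterates
        indices = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return iter(indices.tolist())

    def __len__(self):
        return self.num_samples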

I’m using WeightedRandomSampler.

Considering what you are saying, I should check the dimensions of the array I’m passing at Sampler creation, because one step of

enumerate(DataLoader)

takes 10 times as long as with no Sampler, so I must have set something up wrong.

Something odd is going on.
I just tested a WeightedRandomSampler using your data shape and didn’t see any performance changes.
If it’s possible, I could try to help debug this issue. Alternatively, if you can’t share all the code, post the small snippets that look strange to you, so that we can have a look at them.

I can’t provide all of my code since it is quite extensive. Anyway, below you can find my lines of code regarding the DataLoader and the Sampler, and the respective output. My train and test dimensions are [1530, 600, 8] and [428, 600, 8], and my batch_size is 170.

self.test_samples_weight = [weight[class_id] for class_id in Labels_test]
self.train_samples_weight = np.asarray(self.train_samples_weight)
self.test_samples_weight = np.asarray(self.test_samples_weight)
self.train_samples_weight = 1. / self.train_samples_weight
self.test_samples_weight = 1. / self.test_samples_weight
print(self.test_samples_weight.shape)
print(self.train_samples_weight.shape)

train_sampler = torch.utils.data.WeightedRandomSampler(self.train_samples_weight, len(self.train_samples_weight), replacement=True)
test_sampler = torch.utils.data.WeightedRandomSampler(self.test_samples_weight, len(self.test_samples_weight), replacement=True)
test_loader = DataLoader(dataset=self.dataset_test, batch_size=self.batch_size, sampler=test_sampler, drop_last=True)
train_loader = DataLoader(dataset=self.dataset_train, batch_size=self.batch_size, sampler=train_sampler, drop_last=True)

for epoch in range(self.number_of_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        print(data.size())

and the output is:

(428, 1)
(1530, 1)
torch.Size([170, 1530, 600, 8])

And executing the enumerate part takes at least 2 minutes just to generate one batch.

Please let me know if I’m missing something important.

Thanks for the code snippet!
The issue is most likely caused by the additional dimension in your samples_weight.
Squeeze it and your code should run fine:

self.train_samples_weight = self.train_samples_weight.squeeze(1)
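
For example, with a dummy array of the same shape:

import numpy as np

w = np.ones((1530, 1))
print(w.shape)             # (1530, 1)
print(w.squeeze(1).shape)  # (1530,)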

Thanks for the quick reply; you are right, my code works fine now.

I was using this dimension because, working in PyTorch, I usually had to reshape my input data and target so that the last dimension is 1 (at least in an LSTM or GRU environment).

Ah OK, I see. The docs currently say weights should be a sequence, but maybe we should add some more information on the shape.
What happened is that the additional dimension makes torch.multinomial treat the weights as different distributions:

weights = torch.empty(10).uniform_()
print(torch.multinomial(weights, 10, True))
> tensor([6, 6, 6, 0, 4, 2, 4, 5, 6, 6])
weights = torch.empty(10, 1).uniform_()
print(torch.multinomial(weights, 10, True))
> tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

So basically you received just the 0th sample many times, since each weight row contains only one value.

Yeah, I see that now; I noticed it while printing the samples.

Anyway, I managed to get my code working, but I still have doubts about the sampler. My weights for each class are these:


[0.00961538 0.00155763 0.00127551]

and that’s correct, since my class_0 has only a few occurrences. When I don’t specify any Sampler, I get a class distribution in every batch that looks like this (with batch_size 170):


(array([0, 1, 2]), array([ 6, 75, 89], dtype=int64))
(array([0, 1, 2]), array([11, 65, 94], dtype=int64))
(array([0, 1, 2]), array([13, 80, 77], dtype=int64))
(array([0, 1, 2]), array([15, 73, 82], dtype=int64))
(array([0, 1, 2]), array([10, 66, 94], dtype=int64))

and this looks good; in fact it represents the class distribution in my dataset, as I was expecting.

But with the weighted sampler I get:


(array([0, 1, 2]), array([ 1, 66, 103], dtype=int64))
(array([0, 1, 2]), array([ 1, 75, 94], dtype=int64))
(array([0, 1, 2]), array([ 4, 72, 94], dtype=int64))
(array([0, 1, 2]), array([ 1, 75, 94], dtype=int64))
(array([0, 1, 2]), array([ 4, 61, 105], dtype=int64))
(array([0, 1, 2]), array([ 3, 61, 106], dtype=int64))

What I was expecting is to get more samples of my low-frequency class, but instead I get fewer samples of it. Furthermore, sometimes there are no samples of class_0 in my batch at all, and this totally messes up my metrics evaluation. Do you think this is working properly, or are there still some bugs?

Based on your weights (their inverses are approximately 104, 642, and 784), I assume you might have multiples of this distribution:

class_counts = torch.tensor([104, 642, 784])

If so, I’ve adapted my example code to use your weights and data distribution to get approximately equally distributed batches:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Create dummy data with the given 3-class imbalance
class_counts = torch.tensor([104, 642, 784])
numDataPoints = class_counts.sum()
data_dim = 5
bs = 170
data = torch.randn(numDataPoints, data_dim)

target = torch.cat((torch.zeros(class_counts[0], dtype=torch.long),
                    torch.ones(class_counts[1], dtype=torch.long),
                    torch.ones(class_counts[2], dtype=torch.long) * 2))

print('target train 0/1/2: {}/{}/{}'.format(
    (target == 0).sum(), (target == 1).sum(), (target == 2).sum()))

# Compute samples weight (each sample should get its own weight)
class_sample_count = torch.tensor(
    [(target == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in target])

# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataset = torch.utils.data.TensorDataset(data, target)
#train_dataset = triaxial_dataset(data, target)
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=sampler)

# Iterate DataLoader and check class balance for each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1/2: {}/{}/{}".format(
        i, (y == 0).sum(), (y == 1).sum(), (y == 2).sum()))

> target train 0/1/2: 104/642/784
batch index 0, 0/1/2: 52/60/58
batch index 1, 0/1/2: 63/60/47
batch index 2, 0/1/2: 62/58/50
batch index 3, 0/1/2: 59/60/51
batch index 4, 0/1/2: 45/65/60
batch index 5, 0/1/2: 59/60/51
batch index 6, 0/1/2: 54/56/60
batch index 7, 0/1/2: 59/60/51
batch index 8, 0/1/2: 57/64/49

Could you compare your code to mine and let me know if you get stuck somewhere?
