@smth
Hmm, I tried the code above, but got the following error:
Command:
import torch
import torch.utils.data
batch_size = 20
class_sample_count = [10, 5, 2, 1]
weights = (1 / torch.Tensor(class_sample_count))
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size)
Error:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/sampler.py", line 81, in __init__
    self.weights = torch.DoubleTensor(weights)
TypeError: torch.DoubleTensor constructor received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:
 * no arguments
 * (int ...)
      didn't match because some of the arguments have invalid types: (torch.FloatTensor)
 * (torch.DoubleTensor viewed_tensor)
      didn't match because some of the arguments have invalid types: (torch.FloatTensor)
 * (torch.Size size)
      didn't match because some of the arguments have invalid types: (torch.FloatTensor)
 * (torch.DoubleStorage data)
      didn't match because some of the arguments have invalid types: (torch.FloatTensor)
 * (Sequence data)
      didn't match because some of the arguments have invalid types: (torch.FloatTensor)
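For reference, the traceback shows the sampler passing `weights` straight to `torch.DoubleTensor(...)`, whose accepted overloads include `(Sequence data)` but not a `torch.FloatTensor`. A minimal sketch of one workaround, assuming the dataset labels are available as a plain Python list (the `labels` list and the `per_sample_weights` helper are hypothetical, not part of PyTorch), is to build the weights as a list of Python floats. The sketch also produces one weight per sample rather than one per class, because `WeightedRandomSampler` draws indices from `range(len(weights))`; a four-element class-weight vector like the one above could only ever yield indices 0 to 3.

```python
from collections import Counter


def per_sample_weights(labels):
    """Hypothetical helper: weight each sample by 1 / (size of its class).

    Returns a plain list of Python floats, one entry per sample, so that
    every class contributes the same total weight.
    """
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Hypothetical labels reproducing class_sample_count = [10, 5, 2, 1]
labels = [0] * 10 + [1] * 5 + [2] * 2 + [3] * 1
weights = per_sample_weights(labels)
print(len(weights))  # 18
```

The resulting list could then be passed as `torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size)`. Keeping the tensor version and calling `weights.double()` first should also match the `(torch.DoubleTensor viewed_tensor)` overload listed in the error, though that is untested here.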