Hi all, recently I also have this problem and my Sampler is using WeightedRandomSampler
. I tried to added @albanD 's solution but unable to do so with my custom sampler.
class WeightedRandomSampler2(torch.utils.data.Sampler):
r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
Args:
weights (sequence) : a sequence of weights, not necessary summing up to one
num_samples (int): number of samples to draw
replacement (bool): if ``True``, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a
sample index is drawn for a row, it cannot be drawn again for that row.
Example:
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
"""
def __init__(self, weights, num_samples, replacement=True):
super().__init__()
if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(num_samples))
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(replacement))
self.weights = torch.as_tensor(weights, dtype=torch.double)
self.num_samples = num_samples
self.replacement = replacement
def __iter__(self):
return iter(self.start_idx, torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
def __len__(self):
return self.num_samples
That is by adding self.start_idx
in __iter__()
I received the error when running it in the script:
data_sampler = WeightedRandomSampler2(sample_weights, num_samples=len(sample_weights), replacement=True)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-14-b420c064a884> in <module>()
----> 1 data_sampler = WeightedRandomSampler2(sample_weights, num_samples=len(sample_weights), replacement=True)
<ipython-input-10-3c0a1378cd2d> in __init__(self, weights, num_samples, replacement)
20 def __init__(self, weights, num_samples, replacement=True):
21
---> 22 super().__init__()
23
24 if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or num_samples <= 0:
TypeError: __init__() missing 1 required positional argument: 'data_source'
Any idea?