Dear All,
I am defining my own Sampler. Since it does duplicated sampling, its length will be larger than self.data_source (the Dataset). The tricky part is that I have NOT implemented def __len__(self), yet no error is raised.
Any suggestions? Thank you.
The __len__ of the Dataset is used by the samplers to create the indices, as seen here. If you create a custom sampler, you might skip the len(data_source) usage, but you would need to make sure the created indices are not out of bounds, as seen in this example:
import torch
from torch.utils.data import Dataset, DataLoader

# ============== Standard approach ==============
class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 1)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=5)
for x in loader:
    print(x.shape)

# ============== Missing __len__ ==============
class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 1)

    def __getitem__(self, index):
        return self.data[index]

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=5)
# Fails in the sampler as __len__ is expected
for x in loader:
    print(x.shape)
# TypeError: object of type 'MyDataset' has no len()

# ============== Custom Sampler ==============
class MySampler(object):
    def __init__(self, length):
        self.length = length

    def __iter__(self):
        return iter(range(self.length))

    def __len__(self):
        return self.length

loader = DataLoader(dataset, batch_size=5, sampler=MySampler(10))
# works
for x in loader:
    print(x.shape)

# ============== Wrong length passed to the sampler ==============
loader = DataLoader(dataset, batch_size=5, sampler=MySampler(20))
# fails with IndexError
for x in loader:
    print(x.shape)
# IndexError: index 10 is out of bounds for dimension 0 with size 10
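For the original use case (duplicated sampling, where the sampler is longer than the dataset), here is a minimal sketch of how that could look; OversampleSampler and its parameters are made up for illustration and just draw indices with replacement so they always stay inside the valid range:

import torch
from torch.utils.data import Dataset, DataLoader, Sampler

# Sketch: a sampler that yields MORE indices than the dataset has samples
# (duplicated sampling), while keeping every index in bounds by drawing
# with replacement from range(data_len). Names are illustrative.
class OversampleSampler(Sampler):
    def __init__(self, data_len, num_samples):
        self.data_len = data_len        # size of the underlying dataset
        self.num_samples = num_samples  # may be larger than data_len

    def __iter__(self):
        # draw num_samples indices uniformly with replacement from [0, data_len)
        indices = torch.randint(self.data_len, (self.num_samples,))
        return iter(indices.tolist())

    def __len__(self):
        # used by the DataLoader to compute the number of batches
        return self.num_samples

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 1)

    def __getitem__(self, index):
        return self.data[index]

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=5,
                    sampler=OversampleSampler(data_len=10, num_samples=20))
for x in loader:
    print(x.shape)  # 4 batches of shape [5, 1]

Note that torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=20) already provides this behavior out of the box, so a custom class is only needed if the duplication logic is more involved.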
Very detailed explanation! Thank you!