Infinite DataLoader

The custom Sampler below yields dataset indices forever, so the DataLoader built on top of it never runs out and can be stepped with next() indefinitely.

from torch.utils.data import DataLoader, Dataset, Sampler
import random

class listDataset(Dataset):
    """A minimal map-style dataset over a small list of values."""
    def __init__(self):
        self.varList = [1, 2, 3, 4]
    def __len__(self):
        return len(self.varList)
    def __getitem__(self, idx):
        return self.varList[idx]

class customSampler(Sampler):
    """Sampler that yields indices forever, reshuffling after each full pass."""
    def __init__(self, dataset, shuffle):
        assert len(dataset) > 0
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(order)  # shuffle the very first pass as well
        idx = 0
        while True:  # never raise StopIteration, so the DataLoader never ends
            yield order[idx]
            idx += 1
            if idx == len(order):
                if self.shuffle:
                    random.shuffle(order)
                idx = 0

if __name__ == "__main__":
    dset = listDataset()
    sampler = customSampler(dset, shuffle=True)
    # The sampler never stops, so next() can be called as often as needed,
    # even with a batch_size larger than the dataset itself.
    loader = iter(DataLoader(dataset=dset, sampler=sampler, batch_size=6, num_workers=2))
    for x in range(10):
        i = next(loader)
        print(i)
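
For comparison, here is a minimal sketch (not from the original post) of getting an endless stream of batches without a custom Sampler, by wrapping an ordinary shuffling DataLoader in a small generator that restarts it whenever it is exhausted. It reuses listDataset from above, and infinite_batches is just an illustrative helper name:

from torch.utils.data import DataLoader

def infinite_batches(loader):
    # Restart the DataLoader's iterator every time it is exhausted,
    # so next() on the wrapping generator never raises StopIteration.
    while True:
        for batch in loader:
            yield batch

if __name__ == "__main__":
    plain_loader = DataLoader(listDataset(), batch_size=2, shuffle=True, num_workers=2)
    stream = infinite_batches(plain_loader)
    for x in range(10):
        print(next(stream))

One difference from the sampler approach: batches produced this way never span epoch boundaries, so a batch_size larger than the dataset is capped at len(dataset) items per batch, whereas the infinite sampler keeps filling each batch across passes.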
    