How can I achieve this inside the dataset get function?

I have a Python list of items inside the dataset's get function. The list has just 3 to 5 elements.
Is there a function that can select a distinct element from the list for every dataset item in each new epoch?
Say get(i1) in epoch 1 selects element x,
and get(i1) in epoch 2 selects element y from the list.
Once all elements have been used, they can be reselected in subsequent epochs.
random.sample and shuffle didn't guarantee that different elements are selected before all elements are exhausted.

You could store an internal element counter in the dataset and increment it after each epoch.
Here is a simple example:

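import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(5, 1)
        self.elements = torch.arange(3)
        # Position in self.elements; advanced once per epoch by the training loop
        self.element_idx = 0

    def __getitem__(self, index):
        x = self.data[index]
        # Wrap around so elements are reused once all of them have been selected
        element = self.elements[self.element_idx % len(self.elements)]
        return x, element

    def __len__(self):
        return len(self.data)

# Note: on platforms that spawn worker processes (e.g. Windows), put the code
# below under `if __name__ == '__main__':`
dataset = MyDataset()
loader = DataLoader(dataset, batch_size=2, num_workers=2)

for epoch in range(5):
    # Workers are restarted for each epoch (persistent_workers defaults to False),
    # so they see the updated element_idx
    for data, element in loader:
        print('epoch {}, element {}'.format(epoch, element))

    # increase element_idx
    loader.dataset.element_idx += 1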
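
In the example above every item gets the same element within an epoch. If each item should instead walk through the elements in its own random order, one option is to derive a shuffled order per item from a seed and pick the position from the epoch number. This is only a rough sketch: `PerItemCyclingDataset`, `set_epoch` and `seed` are made-up names for illustration, not part of PyTorch, and the class is written in plain Python (it only needs `__getitem__` and `__len__`, so it could also subclass `torch.utils.data.Dataset`).

```python
import random


class PerItemCyclingDataset:
    """Sketch of a map-style dataset where every item cycles through
    `elements` in its own shuffled order, without repeats inside a cycle."""

    def __init__(self, data, elements, seed=0):
        self.data = list(data)
        self.elements = list(elements)
        self.seed = seed   # hypothetical parameter: base seed for the shuffles
        self.epoch = 0     # set from the training loop via set_epoch()

    def set_epoch(self, epoch):
        # Hypothetical helper: call this before iterating over each epoch
        self.epoch = epoch

    def _order(self, index, cycle):
        # Deterministic permutation for this (item, cycle) pair. It depends only
        # on seed, index and cycle, so no mutable state has to be shared.
        rng = random.Random(self.seed * 1_000_003 + index * 10_007 + cycle)
        order = list(range(len(self.elements)))
        rng.shuffle(order)
        return order

    def __getitem__(self, index):
        # cycle: how many full passes over the elements are already done
        # pos:   position inside the current pass
        cycle, pos = divmod(self.epoch, len(self.elements))
        element = self.elements[self._order(index, cycle)[pos]]
        return self.data[index], element

    def __len__(self):
        return len(self.data)


dataset = PerItemCyclingDataset(data=range(4), elements=["a", "b", "c"])
for epoch in range(6):
    dataset.set_epoch(epoch)
    print(epoch, [dataset[i] for i in range(len(dataset))])
```

Because the order is recomputed from `(seed, index, cycle)` rather than stored, every worker copy of the dataset should agree on the selection as long as the epoch is set before the loader is iterated. A new shuffle starts after each full pass, so an item can get the same element at the end of one pass and the start of the next.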

Let me know if this would work for you.
