If you create your own Dataset, you can do this lazily. Define a variable inside your Dataset that keeps track of the current iteration, and change the returned label according to it. You can also add methods to update this variable, or change it directly when your epoch/iteration is done.

Here is a small example. When you call __getitem__, you get the label corresponding to the current iteration. In this example the label is simply self.iteration, but you would change that logic to whatever you need.
import torch

# Dataset example
class MyCustomDataset(torch.utils.data.Dataset):
    def __init__(self, *args, **kwargs):
        self.data = list(range(10))
        self.iteration = 0  # tracks the current epoch/iteration

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_point = self.data[idx]
        # Here comes your logic to change the label according to the iteration
        label = self.iteration
        return data_point, label

    def change_iteration(self, iteration):
        self.iteration = iteration

    def reset_iteration(self):
        self.iteration = 0
# Small usage example
ds = MyCustomDataset()
dl = torch.utils.data.DataLoader(ds, batch_size=10)

for epoch in range(3):
    ds.change_iteration(epoch)
    for i, (data, lbl) in enumerate(dl):
        print(data, lbl)
# Output
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
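One caveat if you enable the DataLoader's multiprocessing: with num_workers > 0, each worker process receives its own copy of the dataset when the iterator is created, so attribute changes made in the main process mid-epoch will not be visible to the workers. Updating the counter between epochs, as above, should still work with the default persistent_workers=False, because workers are re-created each time you start iterating. A minimal sketch of that pattern (the num_workers value here is just for illustration):

# Set the counter *before* the inner loop, so freshly spawned workers
# (num_workers > 0, persistent_workers=False) copy the updated dataset.
dl = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=2)

for epoch in range(3):
    ds.change_iteration(epoch)  # takes effect when iter(dl) re-creates workers
    for data, lbl in dl:
        print(data, lbl)

With persistent_workers=True the workers (and their dataset copies) survive across epochs, so this trick would no longer propagate the update.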