Hello Pytorchers
I had a similar problem and dealt with it by fixing the number of iterations per epoch and keeping an epoch counter, then using both to compute a sample's index and the length reported by the data loader. Please note the answer was not fully tested, so any recommendation, fix, or feedback is really appreciated.
import os

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    def __init__(self, epoch, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.epoch = epoch  ## step 1
        self.iter = 100     ## step 2 (iterations per epoch)

    def __len__(self):
        return self.iter  ## step 3 (set the loader length to the desired number of iterations)

    def __getitem__(self, idx):
        ## step 4 (compute the shifted index)
        new_idx = idx + (self.iter * self.epoch)
        ## step 5 (handle the wrap-around case)
        if new_idx >= len(self.img_labels):
            new_idx = new_idx % len(self.img_labels)
        ## step 6 (the rest of your code goes here; note that new_idx is used to locate the data point)
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[new_idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[new_idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
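To sanity-check the index arithmetic in steps 4 and 5, here is a minimal sketch with hypothetical toy numbers (7 samples, 3 iterations per epoch) instead of a real dataset:

```python
def new_index(idx, epoch, iters_per_epoch, dataset_len):
    # Step 4: shift the index by how far previous epochs advanced
    shifted = idx + iters_per_epoch * epoch
    # Step 5: wrap around once we run past the end of the data
    return shifted % dataset_len

# Each epoch sees the next window of 3 samples out of 7:
for epoch in range(3):
    print([new_index(i, epoch, 3, 7) for i in range(3)])
# epoch 0 -> [0, 1, 2]
# epoch 1 -> [3, 4, 5]
# epoch 2 -> [6, 0, 1]   (wraps around to the start)
```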
Obviously, this means the dataset and DataLoader must be re-created inside the training loop so that the epoch parameter is updated at the start of each new training epoch, e.g.:
for epoch in range(0, epochs + 1):
    dataset = CustomImageDataset(epoch, annotations_file, img_dir, transform, target_transform)
    train_loader = DataLoader(dataset, batch_size=10)
    train(train_loader, net, loss)
    print('finished epoch {}'.format(epoch))
Hope this helps!