Hello,
I have implemented a bucket iterator using the DataLoader, as done in the torchtext legacy migration guide.
However, with this implementation the bucket DataLoader expires after the first epoch: it yields batches only the first time I iterate over it. The only way I've been able to overcome this is by instantiating a new DataLoader inside the for loop going over the epochs.
Why does this DataLoader not reset at the start of each epoch like a DataLoader usually does?
Below is a snippet of my code for reference.
import random

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, random_split

dataset = EnFRDataset()  # map-style Dataset; each sample is a pair of English-French sentences
# integer split sizes passed to random_split must sum to len(dataset)
splits = [int(0.8*len(dataset))+2, int(0.1*len(dataset)), int(0.1*len(dataset))]
train, val, test = random_split(dataset, splits, generator=torch.Generator().manual_seed(41))

def collate_batch(batch):
    # pad every source and target sequence in the batch to a common length
    src_list, trg_list = [], []
    for data in batch:
        src_list.append(data[0])
        trg_list.append(data[1])
    src = pad_sequence(src_list, padding_value=dataset.en_vocab['<PAD>'], batch_first=True)
    trg = pad_sequence(trg_list, padding_value=dataset.fr_vocab['<PAD>'], batch_first=True)
    return src, trg

def batch_sampler(subset):
    # pair each sample index with the length of its source sentence
    indices = [(i, len(data[0])) for i, data in enumerate(subset)]
    random.shuffle(indices)
    pooled_indices = []
    # create pool of indices with similar lengths
    for i in range(0, len(indices), batch_size * 100):
        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))
    pooled_indices = [x[0] for x in pooled_indices]
    # yield indices for current batch
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i + batch_size]

batch_size = 10

for epoch in range(5):
    # without this line the DataLoader will expire after the first epoch
    bucket_dataloader = DataLoader(train, batch_sampler=batch_sampler(train), collate_fn=collate_batch)
    for _, (X, y) in enumerate(bucket_dataloader):
        pass  # train model
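A possible cause, sketched below in plain Python without torch: calling a generator function such as batch_sampler returns a one-shot iterator, so an object that holds on to it has nothing left to draw from on a second pass. An object whose `__iter__` rebuilds the batches on every call can be iterated repeatedly. The class name `BucketBatchSampler`, its `pool_factor` parameter, and the synthetic `lengths` list are assumptions made for illustration; they are not part of torchtext or PyTorch.

```python
import math
import random


def one_shot_batches(lengths, batch_size):
    # Generator function: the object it returns can be consumed only once.
    for i in range(0, len(lengths), batch_size):
        yield list(range(i, min(i + batch_size, len(lengths))))


class BucketBatchSampler:
    """Hypothetical re-iterable sampler: each iter() call starts a fresh pass."""

    def __init__(self, lengths, batch_size, pool_factor=100):
        self.lengths = lengths
        self.batch_size = batch_size
        self.pool_size = batch_size * pool_factor

    def __iter__(self):
        # Reshuffle and re-bucket on every pass (i.e. every epoch).
        indices = list(enumerate(self.lengths))
        random.shuffle(indices)
        pooled = []
        for i in range(0, len(indices), self.pool_size):
            pooled.extend(sorted(indices[i:i + self.pool_size], key=lambda x: x[1]))
        order = [idx for idx, _ in pooled]
        for i in range(0, len(order), self.batch_size):
            yield order[i:i + self.batch_size]

    def __len__(self):
        return math.ceil(len(self.lengths) / self.batch_size)


# Synthetic sentence lengths standing in for len(data[0]) of each sample.
lengths = [(i * 7) % 23 + 1 for i in range(95)]

gen = one_shot_batches(lengths, batch_size=10)
first_pass = list(gen)    # consumes the generator
second_pass = list(gen)   # the generator is already exhausted

sampler = BucketBatchSampler(lengths, batch_size=10)
epoch_sizes = [len(list(sampler)) for _ in range(3)]  # one fresh pass per "epoch"
```

If this is indeed the cause, passing an instance of such a class as `batch_sampler=` and creating the DataLoader once, outside the epoch loop, should be enough. The lengths would then be computed once up front instead of on every epoch.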
The same issue was raised in another post that used the same approach, but it never got a response.
Thanks.