DataLoader stops after first epoch

Hello,

I have reimplemented the legacy BucketIterator behaviour with a plain DataLoader, following the torchtext legacy migration guide.

However, with this implementation the bucketing DataLoader is exhausted after the first epoch. The only way I’ve found to get around this is to instantiate a new DataLoader inside the for loop over the epochs.

Why is this DataLoader not resetting at the start of each epoch the way it normally does?

Below is a snippet of my code for reference.

import random

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, random_split

dataset = EnFRDataset()  # map-style Dataset; each sample is a pair of English-French sentences
# split sizes (the +2 makes them add up to len(dataset) for my data)
splits = [int(0.8*len(dataset))+2, int(0.1*len(dataset)), int(0.1*len(dataset))]
train, val, test = random_split(dataset, splits, generator=torch.Generator().manual_seed(41))

def collate_batch(batch):
    # pad source and target sentences to a common length within the batch
    src_list, trg_list = [], []
    for data in batch:
        src_list.append(data[0])
        trg_list.append(data[1])
    src = pad_sequence(src_list, padding_value=dataset.en_vocab['<PAD>'], batch_first=True)
    trg = pad_sequence(trg_list, padding_value=dataset.fr_vocab['<PAD>'], batch_first=True)
    return src, trg

def batch_sampler(subset):
    indices = [(i, len(data[0])) for i, data in enumerate(subset)]
    random.shuffle(indices)
    pooled_indices = []
    # create pool of indices with similar lengths 
    for i in range(0, len(indices), batch_size * 100):
        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))
    pooled_indices = [x[0] for x in pooled_indices]
    # yield indices for current batch
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i + batch_size]

batch_size = 10

for epoch in range(5):
    # without re-creating the DataLoader here, it expires after the first epoch
    bucket_dataloader = DataLoader(train, batch_sampler=batch_sampler(train), collate_fn=collate_batch)
    for X, y in bucket_dataloader:
        ...  # train model
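
For what it’s worth, the behaviour also shows up in isolation. Here is a minimal sketch (using a dummy TensorDataset and a trivial generator-based batch sampler purely for illustration, not my actual EnFRDataset) that yields batches only in the first epoch:

import torch
from torch.utils.data import DataLoader, TensorDataset

dummy = TensorDataset(torch.arange(100))  # stand-in dataset, illustration only

def gen_batch_sampler(n, batch_size):
    # a plain generator of index lists; it can only be consumed once
    for i in range(0, n, batch_size):
        yield list(range(i, min(i + batch_size, n)))

loader = DataLoader(dummy, batch_sampler=gen_batch_sampler(len(dummy), 10))

for epoch in range(3):
    n_batches = sum(1 for _ in loader)
    print(f"epoch {epoch}: {n_batches} batches")  # I see 10 batches for epoch 0 and 0 after that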

This issue was also raised in another post doing the same thing, but that post never got a response.

Thanks.