Self-defined collate_fn in DataLoader changes the data structure

(Bruce Su) #1

I have a self-defined collate function that changes the shape of the tensors in a batch (padding variable-length sequences, etc.). When I iterate over the DataLoader twice (or for multiple epochs), it throws an error. After investigating, I found that the input to the second enumerate() was the output of the first, so the data goes through the collate function twice. How can I avoid this?

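from torch.utils.data import DataLoader

# train_set, ER_Collate and BATCH_SIZE are defined elsewhere
train_dataset = DataLoader(dataset=train_set, collate_fn=ER_Collate,
                           batch_size=BATCH_SIZE, shuffle=True)
for i, train_ in enumerate(train_dataset):  # first epoch: runs fine
    pass
for i, train_ in enumerate(train_dataset):  # second epoch: raises an error
    pass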
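A minimal PyTorch sketch of a padding collate function that does not modify its input, assuming the samples are 1-D tensors of different lengths. `VarLenDataset` and `pad_collate` are hypothetical names; `torch.nn.utils.rnn.pad_sequence` is the standard PyTorch helper that stacks variable-length tensors into a newly allocated padded tensor.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class VarLenDataset(Dataset):
    """Hypothetical dataset of 1-D tensors with different lengths."""

    def __init__(self, seqs):
        self.seqs = [torch.tensor(s) for s in seqs]

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return self.seqs[idx]  # the stored tensor itself, not a copy

def pad_collate(batch):
    # pad_sequence allocates a new (batch, max_len) tensor; the sample
    # tensors in `batch` are only read, never resized or overwritten.
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths

train_set = VarLenDataset([[1, 2, 3], [4, 5], [6]])
train_loader = DataLoader(train_set, batch_size=2, shuffle=True,
                          collate_fn=pad_collate)

for epoch in range(2):
    for padded, lengths in train_loader:
        print(epoch, tuple(padded.shape), lengths.tolist())

# The stored samples keep their original lengths after both epochs.
print([len(s) for s in train_set.seqs])  # [3, 2, 1]
```

Because `__getitem__` hands out the stored tensors, any in-place operation inside the collate function (an in-place resize, slice assignment, or appending to a list kept in the dataset) would carry over into the next epoch. Building the batch from fresh tensors avoids that; as a quick workaround for an existing collate function, calling `copy.deepcopy(batch)` at its start should have a similar effect, at the cost of extra copying.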