How does ConcatDataset work?

Hello. This is my CustomDataSet class:

import os

import natsort
from PIL import Image
from torch.utils.data import Dataset


class CustomDataSet(Dataset):
    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        all_imgs = os.listdir(main_dir)
        # Build a filtered list instead of calling remove() while iterating,
        # which would skip the entry right after each removed item.
        self.total_imgs = [
            file_name for file_name in natsort.natsorted(all_imgs)
            if '.txt' not in file_name and file_name != 'semantic'
        ]

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image

Here is how I create a list of datasets:

    all_datasets = []
    while folder_counter < num_train_folders:
        #some code to get path_to_imgs which is the location of the image folder
        train_dataset = CustomDataSet(path_to_imgs, transform)
        all_datasets.append(train_dataset)
        folder_counter += 1

Then I concatenate my datasets, create the DataLoader, and run the training:

final_dataset = torch.utils.data.ConcatDataset(all_datasets)
train_loader = data.DataLoader(final_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=0,
                               pin_memory=True,
                               drop_last=True)

So, is the order of my data preserved? During training, will I go through each folder in the exact order in which the concatenation was done and grab all the images sequentially? For example:

I grab 150 images from folder 1, 100 images from folder 2, and 70 images from folder 3. I concatenate the three datasets. During training I do:

for idx, input_seq in enumerate(train_loader):
    ...  # code to train

So, will the DataLoader go through folder 1 and grab all the images there sequentially, then do the same for folder 2, and finally for folder 3? I tried reading the code for ConcatDataset, but I can't tell whether the order of my data will be preserved.


Yes, the order should be preserved as shown in this simple example using TensorDatasets:

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

datasets = []
for i in range(3):
    datasets.append(TensorDataset(torch.arange(i*10, (i+1)*10)))

dataset = ConcatDataset(datasets)
loader = DataLoader(
    dataset,
    shuffle=False,
    num_workers=0,
    batch_size=2
)

for data in loader:
    print(data)
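To see why the order is preserved, it can help to look at how a global index is mapped back to one of the underlying datasets. The sketch below uses three stand-in TensorDatasets sized like the folders in the question (150, 100, and 70 samples) and reads the `cumulative_sizes` attribute of ConcatDataset. The `locate` helper is hypothetical and only mirrors the lookup for non-negative indices; it is not part of the PyTorch API.

```python
import bisect

import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Stand-ins for the three folders: every sample stores its folder number.
sizes = [150, 100, 70]
datasets = [TensorDataset(torch.full((n,), folder)) for folder, n in enumerate(sizes)]
concat = ConcatDataset(datasets)


def locate(concat_dataset, idx):
    """Hypothetical helper: map a global index to (dataset index, local index)."""
    ds_idx = bisect.bisect_right(concat_dataset.cumulative_sizes, idx)
    offset = 0 if ds_idx == 0 else concat_dataset.cumulative_sizes[ds_idx - 1]
    return ds_idx, idx - offset


print(concat.cumulative_sizes)  # [150, 250, 320]
print(locate(concat, 149))      # (0, 149): last sample of the first dataset
print(locate(concat, 150))      # (1, 0): first sample of the second dataset
print(locate(concat, 319))      # (2, 69): last sample of the third dataset
```

With shuffle=False the DataLoader requests indices 0, 1, 2, ... in order, so the samples come out dataset by dataset, in the order the datasets were passed to ConcatDataset.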

What if I want to grab data from different files, but then have it concatenated and shuffled rather than kept in the same order?
I have data in different folders, so I load each of them separately, but afterwards I want it all shuffled together.

You could use shuffle=True when creating the DataLoader, which will shuffle the passed ConcatDataset.
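As a minimal sketch of that suggestion, the snippet below reuses the three TensorDatasets from the earlier example and sets shuffle=True. The seeded `generator` is an optional addition that only makes the shuffle reproducible; the batch size of 5 is an arbitrary choice for illustration.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Three datasets holding 0..9, 10..19 and 20..29.
datasets = [TensorDataset(torch.arange(i * 10, (i + 1) * 10)) for i in range(3)]
dataset = ConcatDataset(datasets)

# Optional: a seeded generator makes the permutation reproducible.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, shuffle=True, batch_size=5, generator=generator)

seen = []
for (batch,) in loader:
    print(batch)  # batches can mix values from all three datasets
    seen.extend(batch.tolist())
```

Each epoch draws a new permutation over all 30 indices of the concatenated dataset, so every sample is still visited exactly once per epoch.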

from torch.utils.data import DataLoader

list_1 = [1,2,3,4,5]
list_2 = [6,7,8,9,10]
list_3 = [22,23,24,25,26,27]
dataset_list = [list_1, list_2, list_3]

dataset_loader = DataLoader(dataset_list, shuffle=True, batch_size=3)
for i in range(30):
    for x in dataset_loader:
        print(x)

My question is: why does each batch (size 3) contain the same data, only shuffled within itself?
My use case is that I need to shuffle the entire dataset after concatenation, so that in each epoch the batches are drawn from a shuffle over all the datasets [list_1, list_2, list_3].

Passing nested lists to the DataLoader can have this kind of side effect, since each inner list is treated as a single sample. I would recommend creating tensors, passing them to a TensorDataset, and passing that dataset to the DataLoader, which should then properly index and shuffle the data.
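A minimal sketch of that recommendation, assuming every number in the three lists should be its own sample: with the nested list, the DataLoader sees only three samples (one per inner list), whereas a flat tensor wrapped in a TensorDataset gives it 16 samples to shuffle. The variable names `values` and `epoch_values` are chosen here for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

list_1 = [1, 2, 3, 4, 5]
list_2 = [6, 7, 8, 9, 10]
list_3 = [22, 23, 24, 25, 26, 27]

# Flatten everything into one 1-D tensor so each number is a separate sample.
values = torch.tensor(list_1 + list_2 + list_3)
dataset = TensorDataset(values)
loader = DataLoader(dataset, shuffle=True, batch_size=3)

for epoch in range(2):
    epoch_values = []
    for (batch,) in loader:
        epoch_values.extend(batch.tolist())
    print(epoch, epoch_values)  # a fresh ordering over all lists each epoch
```

Because the dataset length is not a multiple of the batch size, the last batch of each epoch is smaller; drop_last=True would discard it instead.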


Sure, thanks!

Hi,
I’m trying to use ConcatDataset to concat the training and testing sets of the CIFAR10 dataset.
For my application, I also need to combine the dataset.targets values of the two sets. But from my understanding, I cannot derive the combined targets using ConcatDataset.
Please suggest how to go about concatenating the datasets such that I can get targets as well.

Double post from here.
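For reference, one possible approach is to build the combined targets yourself from the `datasets` attribute of the ConcatDataset, in the same order the datasets were concatenated. The sketch below is only an illustration: `ToyLabeled` is a hypothetical stand-in that exposes a `targets` list the way torchvision's CIFAR10 does, so the example runs without downloading anything.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class ToyLabeled(Dataset):
    """Hypothetical stand-in for a dataset with a `targets` list, like CIFAR10."""

    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]


train_set = ToyLabeled(torch.randn(6, 3), [0, 1, 2, 0, 1, 2])
test_set = ToyLabeled(torch.randn(4, 3), [2, 2, 1, 0])

combined = ConcatDataset([train_set, test_set])
# Concatenate the targets in the same order as the underlying datasets,
# so combined_targets[i] lines up with the label of combined[i].
combined_targets = [t for ds in combined.datasets for t in ds.targets]
print(combined_targets)  # [0, 1, 2, 0, 1, 2, 2, 2, 1, 0]
```

The same list comprehension should work with the real CIFAR10 training and test sets, assuming both expose `targets` as a list of integer labels.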