Concatenating datasets while preserving their index


I’m trying to concatenate more than one dataset, and after the concatenation the indices of the datasets seem to be out of order.

I have 6 datasets of images and their related labels, and I want to pass them in the same sequence.

This is the length of each dataset:

And this is how I concatenate them:
train_all = torch.utils.data.ConcatDataset({train_good, train_blow_hole, train_break, train_crack, train_fray, train_uneven})

train_all_label = torch.utils.data.ConcatDataset({train_label_good, train_label_blow_hole, train_label_break, train_label_crack, train_label_fray, train_label_uneven})

After the concatenation:

[72, 738, 797, 877, 899, 938]

[72, 111, 777, 836, 858, 938]

The indices no longer line up, and this will cause issues in the training stage later. Is there an efficient way to concatenate the datasets while preserving their indices in the concatenated dataset?


Does anyone have an idea how to solve this, or could anyone point me to a reference?
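One detail worth checking, assuming the braces in the concatenation snippet are literal Python set literals (`{a, b, c}`) and not a copy-paste artifact: a set has no defined order, and ConcatDataset concatenates the datasets in whatever order it iterates over them. That would fit the two lists above, which end at the same total (938) but have different boundaries. The pure-Python sketch below uses a hypothetical `FakeDataset` stand-in with made-up lengths and computes the running totals that ConcatDataset keeps in its `cumulative_sizes` attribute, to show why passing lists keeps images and labels aligned.

```python
from itertools import accumulate


class FakeDataset:
    """Hypothetical stand-in for a Dataset; only its length matters here."""

    def __init__(self, name, length):
        self.name = name
        self.length = length

    def __len__(self):
        return self.length


def cumulative_sizes(datasets):
    # Running total of the lengths, in whatever order `datasets` is iterated.
    # This mirrors what ConcatDataset stores in its `cumulative_sizes` attribute.
    return list(accumulate(len(d) for d in datasets))


# Hypothetical lengths; each image dataset has a label dataset of the same size.
images = [FakeDataset("good", 3), FakeDataset("crack", 5), FakeDataset("fray", 2)]
labels = [FakeDataset("good", 3), FakeDataset("crack", 5), FakeDataset("fray", 2)]

# Lists keep the order they were written in, so both sets of boundaries match.
print(cumulative_sizes(images))  # [3, 8, 10]
print(cumulative_sizes(labels))  # [3, 8, 10]

# A set iterates in hash order (based on object identity by default), so two
# sets holding different objects can come out in different orders.
image_set = set(images)
print([d.name for d in image_set])  # order is not guaranteed
```

With the real datasets, that would mean passing lists, e.g. `torch.utils.data.ConcatDataset([train_good, train_blow_hole, train_break, train_crack, train_fray, train_uneven])`, and the labels in the same order, although keeping data and labels in a single dataset, as suggested in the reply, avoids the problem altogether.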


I am not super familiar with ConcatDataset, but after looking at the docs, here is how I would proceed in your situation. First, I would create a CustomDataset class that inherits from PyTorch’s Dataset class, so that each CustomDataset holds both the data and the targets of one individual dataset. Second, once you have your 6 CustomDataset objects, you can use ConcatDataset to form your new dataset. I believe your mistake is in separating the data from the labels, which should live in the same class. Here is a small example:

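import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, target):
        # Keep samples and labels together so they are always indexed as a pair
        self.data = data
        self.target = target

    def __getitem__(self, index):
        return self.data[index], self.target[index]

    def __len__(self):
        return len(self.data)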

x = torch.rand(size=(666, 10), dtype=torch.float32)
x_labels = torch.zeros(size=(666,), dtype=torch.int32)
y = torch.rand(size=(80, 10), dtype=torch.float32)
y_labels = torch.ones(size=(80,), dtype=torch.int32)
z = torch.rand(size=(59, 10), dtype=torch.float32)
z_labels = torch.zeros(size=(59,), dtype=torch.int32)

x_dataset = CustomDataset(x, x_labels)
y_dataset = CustomDataset(y, y_labels)
z_dataset = CustomDataset(z, z_labels)

train_dataset = torch.utils.data.ConcatDataset([x_dataset, y_dataset, z_dataset])
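To see why this keeps each sample with its label, it helps to know how ConcatDataset resolves a global index. The sketch below is a simplified pure-Python re-implementation of that lookup (not the library's code; the helper name `locate` is hypothetical): it finds the sub-dataset with `bisect.bisect_right` over the running totals of the lengths, then subtracts the previous total to get the position inside it.

```python
import bisect
from itertools import accumulate


def locate(cumulative_sizes, idx):
    """Hypothetical helper: map a global index to (dataset_idx, sample_idx)."""
    total = cumulative_sizes[-1]
    if idx < 0:
        idx += total  # support negative indices like a regular sequence
    if not 0 <= idx < total:
        raise IndexError("index out of range")
    # First sub-dataset whose running total is strictly greater than idx.
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    # Offset inside that sub-dataset: subtract the total of all earlier ones.
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx


# Lengths of x_dataset, y_dataset and z_dataset from the example above.
sizes = list(accumulate([666, 80, 59]))
print(sizes)               # [666, 746, 805]
print(locate(sizes, 665))  # (0, 665): last sample of x_dataset
print(locate(sizes, 666))  # (1, 0): first sample of y_dataset
print(locate(sizes, -1))   # (2, 58): last sample of z_dataset
```

Because each global index is routed to exactly one CustomDataset and one sample index, the data and the label returned by `train_dataset[i]` always come from the same row.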

Hope it helps!

Hi beaupreda,

Thanks, I appreciate the information.

I wasn’t aware of the inheritance feature until you mentioned it; it certainly looks useful.

What I’m actually trying to do is segmentation classification with a U-Net, so in each epoch I need to feed in both an image and a mask. Are you aware of a better way to do this?

Anyway, I will proceed with your suggestion and see how it works. Thanks again!
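For the segmentation case, one option (a sketch, not something confirmed in this thread) is the same pattern with the mask in place of the label: `__getitem__` returns an `(image, mask)` pair read with the same index, so shuffling in a DataLoader moves them together. The class name `SegmentationDataset`, the shapes (images as `(N, C, H, W)`, masks as `(N, H, W)` with integer class ids) and the random data are assumptions for illustration.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class SegmentationDataset(Dataset):
    """Hypothetical dataset holding images and their masks side by side."""

    def __init__(self, images, masks):
        # Both tensors must describe the same N samples, in the same order.
        if len(images) != len(masks):
            raise ValueError("images and masks must have the same length")
        self.images = images
        self.masks = masks

    def __getitem__(self, index):
        # Image and mask are read with the same index, so they stay paired.
        return self.images[index], self.masks[index]

    def __len__(self):
        return len(self.images)


H, W = 32, 32
# Hypothetical random data standing in for two defect categories.
good = SegmentationDataset(torch.rand(10, 3, H, W), torch.zeros(10, H, W, dtype=torch.int64))
crack = SegmentationDataset(torch.rand(6, 3, H, W), torch.ones(6, H, W, dtype=torch.int64))

train_dataset = ConcatDataset([good, crack])  # a list, so the order is fixed
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

images, masks = next(iter(loader))
print(images.shape, masks.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4, 32, 32])
```

With real data, the tensors would more likely be lists of file paths that `__getitem__` loads on demand; the pairing logic stays the same.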