I’m trying to concatenate more than one dataset, but after the concatenation the indices of the samples are no longer in order.
I have 6 datasets of images and their labels, and I want to pass them through in the same sequence.
This is the length of each dataset:
666
80
59
39
22
72
and this is how I concatenate them:
train_all = torch.utils.data.ConcatDataset({train_good,train_blow_hole,train_break,train_crack,train_fray,train_uneven})
The indices are not the same anymore, and this will cause issues later in the training stage. Is there an efficient way to concatenate the datasets while preserving their indices in the concatenated dataset?
I am not super familiar with ConcatDataset, but after looking at the docs, here is how I would proceed in your situation. First, I would create a CustomDataset class that inherits from PyTorch’s torch.utils.data.Dataset, so that each CustomDataset holds both the data and the targets of an individual dataset. Second, once you have your 6 CustomDataset objects, you can use torch.utils.data.ConcatDataset to form your new dataset. I believe your mistake is in separating the data from the labels, which should live in the same class. Here is a small example:
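A minimal sketch of that idea, with toy tensors standing in for your 6 datasets (the shapes here are made up for illustration). Note also that the original call wraps the datasets in `{...}`, which is a Python set and therefore unordered; passing a list keeps the concatenation order deterministic:

```python
import torch
from torch.utils.data import Dataset, ConcatDataset

class CustomDataset(Dataset):
    """Keeps each sample and its label together in one dataset."""
    def __init__(self, data, targets):
        assert len(data) == len(targets), "each sample needs a label"
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return the (image, label) pair so they can never get out of sync.
        return self.data[idx], self.targets[idx]

# Two toy datasets standing in for e.g. train_good and train_blow_hole
ds_a = CustomDataset(torch.randn(4, 3, 8, 8), torch.zeros(4, dtype=torch.long))
ds_b = CustomDataset(torch.randn(2, 3, 8, 8), torch.ones(2, dtype=torch.long))

# Pass a list (ordered), not a set literal, so sample order is preserved
train_all = ConcatDataset([ds_a, ds_b])
print(len(train_all))      # 6
img, label = train_all[4]  # index 4 falls inside ds_b
print(label.item())        # 1
```

ConcatDataset maps a global index to the right underlying dataset via cumulative sizes, so indices 0–3 hit `ds_a` and 4–5 hit `ds_b`, in the order the list was given.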
I wasn’t aware of the inheritance feature until you mentioned it; this sure looks useful.
Actually, what I’m trying to do is perform segmentation with a U-Net, so in each epoch I need to feed in a mask along with each image. Do you know of a better way to do this?
Anyway, I will proceed with your suggestion and see how it works. Thanks again!
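For the segmentation case mentioned above, the same pattern applies: the Dataset can return an (image, mask) pair per sample. A minimal sketch, assuming images and masks are pre-aligned tensors (the shapes are hypothetical):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SegmentationDataset(Dataset):
    """Yields an (image, mask) pair per sample, as a U-Net training loop expects."""
    def __init__(self, images, masks):
        assert len(images) == len(masks), "each image needs a matching mask"
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Image and mask are fetched together, so they stay paired under
        # shuffling or concatenation.
        return self.images[idx], self.masks[idx]

# Toy data: 5 RGB images with one-channel binary masks
images = torch.randn(5, 3, 16, 16)
masks = torch.randint(0, 2, (5, 1, 16, 16)).float()
loader = DataLoader(SegmentationDataset(images, masks), batch_size=2)

img_batch, mask_batch = next(iter(loader))
# img_batch: (2, 3, 16, 16), mask_batch: (2, 1, 16, 16)
```

Several such SegmentationDataset objects (one per defect type) could then be combined with ConcatDataset exactly as in the earlier example.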