Using different custom dataloaders, with same index

Hi there,

I am currently drawing batches from three independent data sources of the same shape (i.e., m-by-n matrices), all loaded on the GPU at once. Because of GPU memory usage, I would like to write my own custom Dataset instead.

The thing is, until now I could pick the item indices myself and use them on all three data sources, but the DataLoader does not tell me which indices it sampled. So my possible solution is to modify __getitem__ in my custom Dataset to return the index as well, as below:

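from torch.utils.data import Dataset
from torchvision import datasets

class MyDataset(Dataset):
    def __init__(self):
        self.cifar10 = datasets.CIFAR10(root='YOUR_PATH')  # further arguments omitted
    def __getitem__(self, index):
        data, target = self.cifar10[index]
        return data, target, index  # also return the index of the drawn sample
    def __len__(self):
        return len(self.cifar10)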

which is from here: []

Here is my question: I want to get the indices from one dataloader and use them to fetch the matching items from the other two data sources (the data shapes are the same). Is there a more efficient way?
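A minimal sketch of the index-returning approach described above, assuming the three sources are in-memory tensors of equal shape. The names `IndexedDataset`, `source_a`, `source_b`, and `source_c` are hypothetical and only serve the illustration. The loader yields the sampled indices alongside rows of the first source, and those indices are then used to slice the other two:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexedDataset(Dataset):
    """Wraps one tensor and returns each row together with its index."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index], index

    def __len__(self):
        return self.data.shape[0]


# Three hypothetical sources with the same m-by-n shape, kept on the CPU.
m, n = 10, 4
source_a = torch.arange(m * n, dtype=torch.float32).reshape(m, n)
source_b = source_a + 1000
source_c = source_a + 2000

loader = DataLoader(IndexedDataset(source_a), batch_size=4, shuffle=True)

batches = []
for rows_a, idx in loader:
    # idx is a tensor holding the sampled row indices; reuse it for the others.
    rows_b = source_b[idx]
    rows_c = source_c[idx]
    batches.append((rows_a, rows_b, rows_c, idx))
```

This keeps the three sources aligned, but the second and third lookups happen outside the DataLoader, so they do not benefit from its worker processes.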

Could you explain a bit, what you mean by

Would you like to determine the indices manually?

Would it work to initialize all three datasets inside your custom MyDataset and just index all three in __getitem__, or am I misunderstanding the use case?
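The suggestion above could look roughly like the following sketch, assuming the three sources fit in CPU memory as tensors with the same number of rows. `TripleDataset` and the tensor names are hypothetical; only the current batch is moved to the GPU, which is one way to address the memory concern:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TripleDataset(Dataset):
    """Holds three same-length tensors and indexes them together."""

    def __init__(self, first, second, third):
        # All sources must have the same number of rows to stay aligned.
        assert first.shape[0] == second.shape[0] == third.shape[0]
        self.first, self.second, self.third = first, second, third

    def __getitem__(self, index):
        return self.first[index], self.second[index], self.third[index]

    def __len__(self):
        return self.first.shape[0]


first = torch.arange(12, dtype=torch.float32).reshape(6, 2)
second = first * 10
third = first * 100

loader = DataLoader(TripleDataset(first, second, third), batch_size=3, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
triples = []
for a, b, c in loader:
    # Only the current batch is transferred to the device.
    a, b, c = a.to(device), b.to(device), c.to(device)
    triples.append((a, b, c))
```

For plain tensors, the built-in `torch.utils.data.TensorDataset(first, second, third)` returns the same kind of tuple per index, so a custom class is only needed when extra logic goes into `__getitem__`.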

I think I found a workaround: initialize each dataset independently, combine them in one wrapper dataset, and index every member with the same index during data loading. Each dataset returns its row under its own key, so every sample contains the same row from all datasets. Because the DataLoader draws a single index per sample and the wrapper passes it to every dataset, shuffle=True shuffles all of them with the same indices. The code is below, and it works for me:

import numpy as np
import torch as T
from torch.utils.data import DataLoader, Dataset

class ContextTensorDataset(Dataset):

    def __init__(self, context_type, directory_path, context_threshold=0, context_neglect=False):

        print(context_type + " Context Tensor Dataset initialized.")
        # Read data if the context is provided
        self.context_matrix = T.FloatTensor(np.load(directory_path + 'user_' + context_type + '_context.npy'))
        self.context_threshold = context_threshold
        self.context_type = context_type

        if context_neglect:
            self.context_matrix[self.context_matrix < self.context_threshold] = 0

    def __getitem__(self, index):
        return {self.context_type: self.context_matrix[index]}

    def __len__(self):
        """Return the length of the data (i.e., the number of users)."""
        return self.context_matrix.shape[0]

    def shuffle_index(self, new_index):
        self.context_matrix = self.context_matrix[new_index]

    def shape(self):
        return self.context_matrix.shape

class ConcatDatasets(Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, index):
        batch = {}
        for dataset in self.datasets:
            batch = {**batch, **dataset[index]}
        return batch

    def __len__(self):
        return min(len(d) for d in self.datasets)
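To check that shuffling keeps the rows aligned, here is a small self-contained sketch. `KeyedTensorDataset` is a hypothetical stand-in for ContextTensorDataset that takes an in-memory tensor instead of loading a .npy file, and the data is synthetic; ConcatDatasets uses the same merging logic as above:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class KeyedTensorDataset(Dataset):
    """Hypothetical stand-in for ContextTensorDataset using an in-memory tensor."""

    def __init__(self, key, matrix):
        self.key = key
        self.matrix = matrix

    def __getitem__(self, index):
        return {self.key: self.matrix[index]}

    def __len__(self):
        return self.matrix.shape[0]


class ConcatDatasets(Dataset):
    """Passes one index to every dataset and merges the resulting dicts."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, index):
        sample = {}
        for dataset in self.datasets:
            sample.update(dataset[index])
        return sample

    def __len__(self):
        return min(len(d) for d in self.datasets)


# Every row stores its own row number so alignment can be verified.
row_ids = torch.arange(8, dtype=torch.float32).unsqueeze(1)
geo_ds = KeyedTensorDataset("geo", row_ids.repeat(1, 3))
time_ds = KeyedTensorDataset("time", row_ids.repeat(1, 3) + 0.5)

loader = DataLoader(ConcatDatasets(geo_ds, time_ds), batch_size=4, shuffle=True)

aligned = True
for batch in loader:
    # The default collate function stacks each dict value into a batch tensor.
    same_rows = torch.equal(batch["geo"][:, 0] + 0.5, batch["time"][:, 0])
    aligned = aligned and same_rows
```

If `aligned` is still true after the loop, every batch contained the same rows from both datasets, regardless of the shuffle order.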

The calling code then looks roughly like this:

    datasets = ()  # assumed starting value; the initialization is not shown here
    if 'G' in args.input_val:
        place_data = ContextTensorDataset('geo', directory_path, args.context, context_neglect=True)
        datasets += (place_data,)
    if 'T' in args.input_val:
        time_data = ContextTensorDataset('time', directory_path, args.context, context_neglect=True)
        datasets += (time_data,)
    if 'S' in args.input_val:
        seq_data = ContextTensorDataset('seq', directory_path, args.context, context_neglect=True)
        datasets += (seq_data,)

    dataloaders = DataLoader(ConcatDatasets(*datasets),