torch.utils.data.dataset.random_split

Here is a small example:

class MyDataset(Dataset):
    """Wrap an indexable dataset and optionally transform each sample's input.

    Items of ``subset`` are unpacked as ``(x, y)`` pairs; only ``x`` is
    passed through ``transform``, ``y`` is returned untouched.
    """

    def __init__(self, subset, transform=None):
        """Store the wrapped dataset and the optional callable applied to ``x``."""
        self.subset = subset
        self.transform = transform

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.subset)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair at ``index``, with ``x`` transformed if set."""
        sample, target = self.subset[index]
        # No (truthy) transform configured: hand the pair back as stored.
        if not self.transform:
            return sample, target
        return self.transform(sample), target

    
# Toy dataset: 100 random 3x24x24 "images" with integer labels in [0, 10).
init_dataset = TensorDataset(
    torch.randn(100, 3, 24, 24),
    torch.randint(0, 10, (100,))
)

# 80/20 split. The second length is the remainder rather than a second
# int(...) truncation, so the lengths always sum to len(init_dataset);
# random_split raises a ValueError when they do not.
num_samples = len(init_dataset)
len_a = int(num_samples * 0.8)
lengths = [len_a, num_samples - len_a]
subsetA, subsetB = random_split(init_dataset, lengths)

# Wrap each subset so it can carry its own transform (the same Normalize
# settings are used for both here).
datasetA = MyDataset(
    subsetA, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)
datasetB = MyDataset(
    subsetB, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)

Let me know if that would work for you.

19 Likes