How to get a part of a dataset?

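If you are fine with approximately 100 randomly drawn samples per class, this code should work. The split is stratified, so it preserves the class proportions of the full dataset and the per-class counts will be close to, but not exactly, 100: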

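# Imports
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

# Setup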
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

dataset = torchvision.datasets.MNIST('./data/', train=True, transform=transform)

# Split the indices in a stratified way
indices = np.arange(len(dataset))
train_indices, test_indices = train_test_split(indices, train_size=100*10, stratify=dataset.targets)

# Wrap into Subsets and DataLoaders
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

train_loader = DataLoader(train_dataset, shuffle=True, num_workers=2, batch_size=10)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=2, batch_size=10)


# Validation
train_targets = []
for _, target in train_loader:
    train_targets.append(target)
train_targets = torch.cat(train_targets)

print(train_targets.unique(return_counts=True))
> (tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([ 99, 112,  99, 102,  97,  90,  99, 105,  98,  99]))

Otherwise, if you need exactly 100 samples per class, you could probably loop over the classes, draw 100 random indices from each one, and then use the same Subset approach.
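
A minimal sketch of that per-class loop, assuming the labels are available as a flat sequence of integers (for MNIST, something like dataset.targets.tolist()). The helper name sample_indices_per_class and the seed parameter are made up for illustration and are not part of torchvision; toy labels stand in for the real targets so the snippet runs on its own:

```python
import random
from collections import defaultdict


def sample_indices_per_class(targets, n_per_class, seed=0):
    """Return sorted indices containing exactly n_per_class samples of each label."""
    rng = random.Random(seed)

    # Group the dataset indices by their class label
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[int(label)].append(idx)

    # Draw n_per_class indices without replacement from every class
    picked = []
    for label in sorted(by_class):
        candidates = by_class[label]
        if len(candidates) < n_per_class:
            raise ValueError(f"class {label} has only {len(candidates)} samples")
        picked.extend(rng.sample(candidates, n_per_class))
    return sorted(picked)


# Toy labels standing in for dataset.targets.tolist() (assumption for this sketch)
toy_targets = [i % 10 for i in range(2000)]
balanced_indices = sample_indices_per_class(toy_targets, n_per_class=100)
print(len(balanced_indices))
```

With the MNIST dataset from the code above, the returned list could be passed straight to Subset(dataset, balanced_indices), and the same validation loop should then report 100 samples for every class.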
