How to use one class of number in MNIST

Hello, I’m studying MNIST and want to train a model on only the digit “1”, but I don’t know how to extract the “1” class from the full dataset. I only know this code:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


You could get the indices of all class-1 labels and then index both the labels and the data:

dataset = datasets.MNIST(root='./data')
idx = dataset.train_labels==1
dataset.train_labels = dataset.train_labels[idx]
dataset.train_data = dataset.train_data[idx]

However, your model won’t learn anything useful, as it only ever sees one class.
Could you explain your use case a bit?
I would at least try to keep two classes in the dataset.
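If you do keep two (or more) classes, the same masking idea extends to a list of wanted labels. A minimal sketch, assuming data and targets are tensors (as dataset.data and dataset.targets are for MNIST in recent torchvision versions); keep_classes is a hypothetical helper name, and torch.isin needs PyTorch 1.10 or newer:

```python
import torch


def keep_classes(data, targets, classes):
    """Return (data, targets) restricted to samples whose label is in `classes`."""
    targets = torch.as_tensor(targets)
    # Boolean mask: True where the label is one of the wanted classes
    mask = torch.isin(targets, torch.as_tensor(classes))
    # For a NumPy data array (e.g. CIFAR10), index with mask.numpy() instead
    return data[mask], targets[mask]


# Synthetic stand-in for MNIST: 6 "images" of size 2x2 with labels 0, 1, 2, 1, 7, 0
data = torch.arange(6 * 4).reshape(6, 2, 2)
targets = torch.tensor([0, 1, 2, 1, 7, 0])
sub_data, sub_targets = keep_classes(data, targets, [0, 1])
print(sub_targets)  # tensor([0, 1, 1, 0])
```

With a real dataset this would be called as keep_classes(dataset.data, dataset.targets, [0, 1]) and the results assigned back.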

Or do you want to train your model to recognize the number 1 as the valid class and all remaining numbers as false?
If so, you could try this code:

idx = dataset.train_labels != 1
dataset.train_labels[idx] = 0
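The in-place assignment above works because the positive class is 1, so setting everything else to 0 yields binary labels. For an arbitrary positive class, a sketch that builds fresh 0/1 targets instead (one_vs_rest is a hypothetical helper; you would assign its result back to the dataset's targets):

```python
import torch


def one_vs_rest(targets, positive=1):
    """Map `positive` to 1 and every other label to 0 (binary targets)."""
    targets = torch.as_tensor(targets)
    # The comparison yields a bool tensor; .long() turns it into 0/1 class indices
    return (targets == positive).long()


targets = torch.tensor([5, 0, 4, 1, 9, 1])
print(one_vs_rest(targets))  # tensor([0, 0, 0, 1, 0, 1])
```

Returning new 0/1 labels also avoids a subtle issue with the in-place version: with positive=7, the "valid" samples would keep label 7, which a 2-output classifier cannot use.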

Thank you! I actually want to train on more than one class of digit; that was a typo.

It says “AttributeError: can’t set attribute” on the dataset.train_labels = dataset.train_labels[idx] line. What is wrong?

Could you try to access dataset.data and dataset.targets instead?
I think the attributes were renamed in recent torchvision versions.


Thank you, it worked.


This should work. Change the if condition in my_collate according to your needs.

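import torch
from torch.utils.data.dataloader import default_collate
from torchvision import datasets, transforms

batch_size = 64  # example value
transform = transforms.ToTensor()

def my_collate(batch):
    # keep only the samples whose label satisfies the condition
    modified_batch = []
    for item in batch:
        image, label = item
        if label == 1:
            modified_batch.append(item)
    return default_collate(modified_batch)

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True, collate_fn=my_collate)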
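Note that this approach still loads every sample, produces batches of varying size, and fails if a batch contains no wanted sample at all, because default_collate cannot handle an empty list. A sketch of a guarded variant, assuming the training loop skips batches that come back as None (make_filter_collate is a hypothetical name):

```python
import torch
from torch.utils.data.dataloader import default_collate


def make_filter_collate(keep_classes):
    """Build a collate_fn that drops samples whose label is not in keep_classes."""
    keep = set(keep_classes)

    def collate(batch):
        kept = [(image, label) for image, label in batch if int(label) in keep]
        if not kept:
            # default_collate cannot handle an empty list; signal "skip this batch"
            return None
        return default_collate(kept)

    return collate


# Synthetic batch of (image, label) pairs standing in for MNIST samples
batch = [(torch.zeros(1, 28, 28), label) for label in [3, 1, 1, 7]]
images, labels = make_filter_collate([1])(batch)
print(images.shape, labels)  # torch.Size([2, 1, 28, 28]) tensor([1, 1])
```

It would be passed as collate_fn=make_filter_collate([1]). Filtering the dataset itself (as in the other answers) avoids both the wasted loading and the uneven batches.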

I tried using the same on CIFAR10, but it does not work.

idx = dataset.targets ==1

This gives a single bool (True or False) as output, not a mask over all samples.
Is it different for each dataset, or is there a standard way of doing it?

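It is most likely different: MNIST stores its targets as a tensor, while CIFAR10 stores them as a plain Python list, so comparing the list to 1 just returns False.
For CIFAR10, you could wrap the targets in a tensor and use the comparison: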

dataset = datasets.CIFAR10(root='./data')
idx = torch.tensor(dataset.targets)==1
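The difference is plain Python equality versus element-wise tensor comparison. A small sketch with made-up labels standing in for CIFAR10's targets list:

```python
import torch

# CIFAR10 stores targets as a plain Python list (MNIST uses a tensor)
targets_list = [6, 9, 1, 1, 2]

# Comparing a list to an int is ordinary Python equality: one bool, not a mask
print(targets_list == 1)  # False

# Wrapping the list in a tensor gives an element-wise comparison
mask = torch.tensor(targets_list) == 1
print(mask.tolist())  # [False, False, True, True, False]

# Integer positions of the matching samples
print(mask.nonzero().flatten().tolist())  # [2, 3]
```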

I tried this,

train_dataset = datasets.CIFAR10(root='./data')
idx = torch.tensor(train_dataset.targets) == 1
dataset.data = dataset.data[idx]
dataset.targets = dataset.targets[idx]

I am getting this error

=>              dataset.targets = dataset.targets[idx]

TypeError: only integer tensors of a single element can be converted to an index

This should work:

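import torch
from torchvision import datasets

dataset = datasets.CIFAR10(root='./data')
dataset.targets = torch.tensor(dataset.targets)
idx = dataset.targets == 1
dataset.targets = dataset.targets[idx]
# CIFAR10 stores images in a NumPy array, so index it with a NumPy bool mask
dataset.data = dataset.data[idx.numpy().astype(bool)]

for data, target in dataset:
    pass  # process each (image, label) pair here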

However, I would consider these approaches hacks, and there might be side effects I’m currently not aware of.
The clean approach would be to write your own dataset class (or subclass the existing one) and manipulate the underlying data as you wish.
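One way to do this without touching the original dataset's internals is a small wrapper Dataset. A minimal sketch, assuming the wrapped dataset exposes a targets attribute (or that you pass the labels explicitly); ClassFilteredDataset is a hypothetical name, and it behaves much like torch.utils.data.Subset with precomputed indices:

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class ClassFilteredDataset(Dataset):
    """Expose only the samples of `base` whose label is in `classes`.

    `targets` defaults to `base.targets`, which torchvision datasets such as
    MNIST and CIFAR10 provide; pass it explicitly for other datasets.
    """

    def __init__(self, base, classes, targets=None):
        if targets is None:
            targets = base.targets
        keep = set(classes)
        self.base = base
        # Precompute the positions of the wanted samples once
        self.indices = [i for i, t in enumerate(targets) if int(t) in keep]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Delegate to the wrapped dataset, so its transforms still apply
        return self.base[self.indices[i]]


# Synthetic example: 5 samples with labels 0, 1, 2, 1, 0
x = torch.arange(5).float().unsqueeze(1)
y = torch.tensor([0, 1, 2, 1, 0])
filtered = ClassFilteredDataset(TensorDataset(x, y), classes=[1], targets=y)
print(len(filtered))  # 2
```

With CIFAR10 this would be ClassFilteredDataset(datasets.CIFAR10(root='./data'), classes=[1]), leaving the original data and targets unchanged.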

Thank you, your solution works for me.

I also found another method here:

I also included an example implementation here:

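import torch.utils.data as Data
import torchvision

def get_indices(dataset, class_name):
    # collect the positions of all samples whose label equals class_name
    indices = []
    for i in range(len(dataset.targets)):
        if dataset.targets[i] == class_name:
            indices.append(i)
    return indices

dataset = torchvision.datasets.CIFAR10(root='./data', transform=torchvision.transforms.ToTensor())

idx = get_indices(dataset, 1)
loader = Data.DataLoader(dataset, batch_size=64, sampler=Data.sampler.SubsetRandomSampler(idx))

for batch_idx, (data, target) in enumerate(loader):
    pass  # training step here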

I hope this works for all the datasets available in torchvision.datasets (at least those that store their labels in a targets attribute).
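The Python loop is fine for MNIST- or CIFAR-sized data, but the same indices can be computed in one vectorized step. Also note that SubsetRandomSampler shuffles, so for an evaluation loader torch.utils.data.Subset (which keeps the original order) may be preferable. A sketch on synthetic labels; get_indices_fast is a hypothetical name:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def get_indices_fast(targets, class_id):
    """Vectorized get_indices: positions where targets == class_id."""
    return torch.nonzero(torch.as_tensor(targets) == class_id).flatten().tolist()


targets = [3, 1, 4, 1, 5, 9, 2, 6]
print(get_indices_fast(targets, 1))  # [1, 3]

# Subset keeps the original order; SubsetRandomSampler would shuffle it instead
x = torch.arange(len(targets)).float()
subset = Subset(TensorDataset(x, torch.tensor(targets)), get_indices_fast(targets, 1))
loader = DataLoader(subset, batch_size=2, shuffle=False)
batches = list(loader)
print(batches[0][1])  # tensor([1, 1])
```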


That looks like a cleaner approach, so I would recommend sticking with the SubsetRandomSampler.


Another option is to create the dataset as a Subset and then define the data loaders. Assuming you only want labels 1 and 2:

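import numpy as np
import torch
from torchvision import datasets, transforms

CIFAR10_train = datasets.CIFAR10(root='./data', train=True, transform=transforms.ToTensor())
CIFAR10_test = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor())

idx = torch.tensor(CIFAR10_train.targets) == 1
idx |= torch.tensor(CIFAR10_train.targets) == 2  # logical OR of the two masks
dset_train = torch.utils.data.Subset(CIFAR10_train, np.where(idx.numpy())[0])

idx = torch.tensor(CIFAR10_test.targets) == 1
idx |= torch.tensor(CIFAR10_test.targets) == 2
dset_test = torch.utils.data.Subset(CIFAR10_test, np.where(idx.numpy())[0])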

and dataloaders are defined as usual:

dl_train = torch.utils.data.DataLoader(dset_train, batch_size=8, shuffle=True)
dl_test = torch.utils.data.DataLoader(dset_test, batch_size=8, shuffle=True)

This works too and seems to be simpler than SubsetRandomSampler.

Isn’t it also necessary to reindex the labels to do something useful with this dataset?

For example, the CrossEntropyLoss criterion expects class indices in the range [0, C-1].
So, in this case labels 1 and 2 have to be remapped to 0 and 1; otherwise the criterion will throw an error.
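One way to do the remapping without modifying the underlying dataset is to wrap the already filtered dataset and translate labels on the fly. A minimal sketch; RemapLabels is a hypothetical name, and a label outside `classes` raises a KeyError, which doubles as a sanity check on the filtering:

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class RemapLabels(Dataset):
    """Wrap a dataset and map its original labels to 0..C-1 on the fly."""

    def __init__(self, base, classes):
        self.base = base
        # e.g. classes=[1, 2] -> {1: 0, 2: 1}
        self.mapping = {c: i for i, c in enumerate(classes)}

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        return x, self.mapping[int(y)]


base = TensorDataset(torch.zeros(4, 3), torch.tensor([1, 2, 2, 1]))
remapped = RemapLabels(base, classes=[1, 2])
print([remapped[i][1] for i in range(len(remapped))])  # [0, 1, 1, 0]
```

Applied to the example above, this would be RemapLabels(dset_train, classes=[1, 2]).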

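As long as every label is a valid index into the network’s outputs (i.e. the number of outputs is larger than the largest label), you will not face any issues. With labels 1 and 2 this means at least 3 outputs, so a 2-output network would still need the labels remapped to 0 and 1.

Also, if you have 10 outputs for the network and you are only training on 2 classes, there will be no issue either.

The condition is violated only when a label falls outside the range of outputs, for example when your dataset has 10 classes and your network has only 2 outputs, but not the other way around.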
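The rule can be checked directly on the CPU with random logits standing in for model outputs (on the GPU the same mistake typically surfaces as a less readable CUDA assertion):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([1, 2, 2, 1])  # labels 1 and 2, not remapped

# 10 output units: labels 1 and 2 are valid indices, so the loss is computed
loss = F.cross_entropy(torch.randn(4, 10), labels)
print(torch.isfinite(loss).item())  # True

# 2 output units: only 0 and 1 are valid indices, so label 2 is out of range
try:
    F.cross_entropy(torch.randn(4, 2), labels)
    failed = False
except (IndexError, RuntimeError):
    failed = True
print(failed)  # True
```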

I used this approach.
Here’s my implementation for split-cifar100 (a common continual learning benchmark):

import numpy as np
import torch
import torchvision

def get_split_cifar100(task_id, batch_size=32, shuffle=False):
    # convention: tasks start from 1, not 0!
    # task_id = 1 (i.e., first task) => start_class = 0, end_class = 5 (exclusive), i.e. classes 0-4
    start_class = (task_id-1)*5
    end_class = task_id * 5

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train = torchvision.datasets.CIFAR100('./data/', train=True, download=True, transform=transforms)
    test = torchvision.datasets.CIFAR100('./data/', train=False, download=True, transform=transforms)
    targets_train = torch.tensor(train.targets)
    target_train_idx = ((targets_train >= start_class) & (targets_train < end_class))

    targets_test = torch.tensor(test.targets)
    target_test_idx = ((targets_test >= start_class) & (targets_test < end_class))

    train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train, np.where(target_train_idx.numpy())[0]), batch_size=batch_size, shuffle=shuffle)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(test, np.where(target_test_idx.numpy())[0]), batch_size=batch_size)

    return train_loader, test_loader

I want to create a separate DataLoader for each of several subtasks, each consisting of identifying two digits: [0,1], [2,3], [4,5], [6,7], [8,9]. How can I do this?

You could use the same approach from this post and combine the conditions for the two wanted classes to create idx. The easiest approach would be to create 5 MNIST datasets and apply the conditions on each separately.

E.g. for the first DataLoader, this code should work:

from torchvision import datasets

dataset = datasets.MNIST(root='./data')
idx = (dataset.targets==0) | (dataset.targets==1)
dataset.targets = dataset.targets[idx]
dataset.data = dataset.data[idx]
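As an alternative to creating five dataset copies, all five loaders can share one dataset through Subset, with the per-class conditions combined via torch.isin (PyTorch 1.10 or newer). A sketch on synthetic data; make_task_loaders is a hypothetical helper, and with real MNIST it would be called as make_task_loaders(dataset, dataset.targets, groups):

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def make_task_loaders(dataset, targets, class_groups, batch_size=64):
    """Build one DataLoader per group of classes, all sharing the same base dataset."""
    targets = torch.as_tensor(targets)
    loaders = []
    for group in class_groups:
        # Equivalent to (targets == 0) | (targets == 1) for group [0, 1]
        mask = torch.isin(targets, torch.tensor(group))
        indices = torch.nonzero(mask).flatten().tolist()
        loaders.append(DataLoader(Subset(dataset, indices), batch_size=batch_size, shuffle=True))
    return loaders


groups = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
# Synthetic stand-in for MNIST: two samples per digit
y = torch.arange(10).repeat(2)
x = torch.zeros(len(y), 1, 28, 28)
loaders = make_task_loaders(TensorDataset(x, y), y, groups, batch_size=4)
print([len(loader.dataset) for loader in loaders])  # [4, 4, 4, 4, 4]
```

Keeping a single base dataset avoids holding five copies of the images in memory; if the labels should start at 0 within each task, they still need remapping as discussed above.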