How to use one class of number in MNIST

Hello, I’m studying MNIST and want to train a model with only the digit “1”, but I don’t know how to extract the “1” class from the full dataset. So far I only know this code:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

Thanks

You could create a mask for all class 1 labels and use it to index both the labels and the data:

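from torchvision import datasets

dataset = datasets.MNIST(root='./data', download=True)
# train_labels / train_data are the attribute names used by older torchvision versions
idx = dataset.train_labels == 1
dataset.train_labels = dataset.train_labels[idx]
dataset.train_data = dataset.train_data[idx]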

However, your model won’t learn anything as you just have one class.
Could you explain your use case a bit?
I would at least try to keep two classes in the dataset.

Or do you want to train your model to recognize the number 1 as the valid class and all remaining numbers as false?
If so, you could try this code:

idx = dataset.train_labels != 1
dataset.train_labels[idx] = 0
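The one-vs-rest relabeling above can be checked on a small synthetic label tensor. This is a minimal sketch, assuming the labels are an integer torch.Tensor like MNIST’s targets; the helper name to_one_vs_rest is hypothetical, and it returns a new tensor instead of modifying the dataset in place:

```python
import torch

def to_one_vs_rest(targets: torch.Tensor, positive: int = 1) -> torch.Tensor:
    """Hypothetical helper: map `positive` to 1 and every other label to 0."""
    return (targets == positive).long()

# synthetic labels standing in for dataset.targets
labels = torch.tensor([0, 1, 7, 1, 3, 9])
binary = to_one_vs_rest(labels)
print(binary.tolist())  # [0, 1, 0, 1, 0, 0]
```

Unlike setting the non-1 labels to 0 in place, this also works when the positive class is not 1, because the positive label is always mapped to 1.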

Thank you! Actually, I want to train on more than one digit class; the single class was a typo.

I get “AttributeError: can’t set attribute” on the dataset.train_labels = dataset.train_labels[idx] line. What is wrong?

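Could you try accessing dataset.data and dataset.targets instead?
These attributes were renamed in newer torchvision versions, and train_labels and train_data are now deprecated read-only aliases, which is why assigning to them fails.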
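A minimal sketch of the same filtering with the new attribute names, generalized to keep several classes. It assumes dataset.targets is a tensor (as in torchvision’s MNIST) and that dataset.data can be indexed by the same boolean mask; keep_classes is a hypothetical helper, and a SimpleNamespace stands in for the real dataset so the example runs without downloading anything. With the real data you would call keep_classes(datasets.MNIST(root='./data', download=True), classes=(1, 7)).

```python
import torch
from types import SimpleNamespace

def keep_classes(dataset, classes):
    """Hypothetical helper: keep only samples whose label is in `classes`.

    Assumes dataset.targets is a tensor and dataset.data can be indexed
    with the same boolean mask (true for torchvision's MNIST).
    """
    targets = torch.as_tensor(dataset.targets)
    mask = torch.zeros_like(targets, dtype=torch.bool)
    for c in classes:
        mask |= targets == c  # element-wise OR over the wanted classes
    dataset.data = dataset.data[mask]
    dataset.targets = targets[mask]
    return dataset

# stand-in for datasets.MNIST(root='./data', download=True)
fake = SimpleNamespace(data=torch.arange(6 * 4).reshape(6, 2, 2),
                       targets=torch.tensor([0, 1, 2, 1, 3, 2]))
keep_classes(fake, classes=(1, 2))
print(fake.targets.tolist())  # [1, 2, 1, 2]
```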


Thank you, it worked.


This should work. Change the if condition in my_collate according to your needs.

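import torch
from torch.utils.data.dataloader import default_collate
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # example transform
batch_size = 64                    # example batch size

def my_collate(batch):
    # keep only the samples whose label passes the condition
    modified_batch = []
    for item in batch:
        image, label = item
        if label == 1:
            modified_batch.append(item)
    # an empty modified_batch would make default_collate fail
    return default_collate(modified_batch)

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True, collate_fn=my_collate)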
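One caveat with filtering inside collate_fn: the DataLoader still reads every sample, so batches come out smaller than batch_size, and a batch with no matching labels leaves default_collate with an empty list, which raises an error. A minimal sketch, assuming it is acceptable to return None for such batches and skip them in the loop; filtering_collate and its wanted argument are hypothetical, and a small TensorDataset stands in for MNIST:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate

def filtering_collate(batch, wanted=(1,)):
    """Hypothetical collate_fn: drop samples whose label is not in `wanted`."""
    kept = [item for item in batch if int(item[1]) in wanted]
    if not kept:
        return None  # signal an empty batch instead of crashing
    return default_collate(kept)

# synthetic stand-in: 8 "images" with labels 0..3
images = torch.randn(8, 1, 28, 28)
labels = torch.tensor([0, 0, 0, 0, 1, 2, 1, 3])
loader = DataLoader(TensorDataset(images, labels), batch_size=4,
                    shuffle=False, collate_fn=filtering_collate)

seen = []
for batch in loader:
    if batch is None:  # the first batch contains no label 1
        continue
    x, y = batch
    seen.extend(y.tolist())
print(seen)  # [1, 1]
```

Because every sample is still loaded and then discarded, this wastes I/O compared with selecting indices up front.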

I tried the same approach on CIFAR10, but it does not work:

idx = dataset.targets ==1

This line returns a single bool (True or False) instead of an element-wise mask.
Is it different for different datasets, or is there a standard way of doing this?

It can differ between datasets.
For CIFAR10, you could wrap the targets in a tensor and then use the comparison:

import torch
from torchvision import datasets

dataset = datasets.CIFAR10(root='./data', download=True)
idx = torch.tensor(dataset.targets) == 1
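The reason is that MNIST stores its targets as a tensor, while CIFAR10 stores them as a plain Python list, and comparing a list to an integer with == yields one bool for the whole list rather than an element-wise mask. A small illustration with a made-up label list:

```python
import torch

targets = [6, 1, 9, 1, 4]          # CIFAR10-style: a plain Python list

print(targets == 1)                # False: one comparison of the whole list
mask = torch.tensor(targets) == 1  # tensor comparison is element-wise
print(mask.tolist())               # [False, True, False, True, False]

# the mask can index tensors/arrays, but not the original list
kept = torch.tensor(targets)[mask]
print(kept.tolist())               # [1, 1]
```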

I tried this:

dataset = datasets.CIFAR10(root='./data')
idx = torch.tensor(dataset.targets) == 1
dataset.data = dataset.data[idx]
dataset.targets = dataset.targets[idx]

I am getting this error

=>              dataset.targets = dataset.targets[idx]

TypeError: only integer tensors of a single element can be converted to an index

This should work:

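import torch
from torchvision import datasets, transforms

dataset = datasets.CIFAR10(root='./data', download=True,
                           transform=transforms.ToTensor())
# CIFAR10 stores targets as a list, so convert them to a tensor first
dataset.targets = torch.tensor(dataset.targets)
idx = dataset.targets == 1
dataset.targets = dataset.targets[idx]
# data is a NumPy array; np.bool was removed in NumPy 1.24, so use bool
dataset.data = dataset.data[idx.numpy().astype(bool)]

for data, target in dataset:
    print(data.sum())
    print(target)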

However, I would consider these approaches hacks, and there might be side effects I’m currently not aware of.
The cleaner approach would be to subclass the dataset class and manipulate the underlying data as you wish.

Thank you, your solution works for me.

I also found another method here:

I also included an example implementation here:

import torch.utils.data as Data
import torchvision

def get_indices(dataset, class_name):
    # collect the position of every sample whose label equals class_name
    indices = []
    for i in range(len(dataset.targets)):
        if dataset.targets[i] == class_name:
            indices.append(i)
    return indices


dataset = torchvision.datasets.CIFAR10(root='./data', download=True,
                                       transform=torchvision.transforms.ToTensor())

idx = get_indices(dataset, 1)
# SubsetRandomSampler draws only from the given indices, in random order
loader = Data.DataLoader(dataset, batch_size=64,
                         sampler=Data.sampler.SubsetRandomSampler(idx))

for batch_idx, (data, target) in enumerate(loader):
    print(target)

I hope this works for all the datasets available in torchvision.datasets.


That looks like a cleaner approach, so I would recommend sticking with the SubsetRandomSampler.
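If the Python loop in get_indices becomes slow for large datasets, the indices can also be computed in one vectorized step. A sketch, assuming the targets can be converted with torch.as_tensor (true for both MNIST’s tensor and CIFAR10’s list); class_indices is a hypothetical name, and a TensorDataset stands in for the real data. Note that DataLoader rejects shuffle=True together with a sampler, and SubsetRandomSampler already shuffles:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

def class_indices(targets, classes):
    """Hypothetical helper: indices of all samples whose label is in `classes`."""
    targets = torch.as_tensor(targets)
    mask = torch.zeros_like(targets, dtype=torch.bool)
    for c in classes:
        mask |= targets == c
    return torch.nonzero(mask).flatten().tolist()

# synthetic stand-in for a dataset with a CIFAR10-style list of targets
targets = [3, 1, 1, 0, 2, 1, 3]
dataset = TensorDataset(torch.arange(len(targets)), torch.tensor(targets))

idx = class_indices(targets, classes=(1,))
print(idx)  # [1, 2, 5]

# the sampler shuffles, so shuffle=True must not be passed as well
loader = DataLoader(dataset, batch_size=2, sampler=SubsetRandomSampler(idx))
drawn = sorted(i for x, _ in loader for i in x.tolist())
print(drawn)  # [1, 2, 5]
```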


Another option is to create the dataset as a Subset and then define the data loaders. Assuming you only want labels 1 and 2:

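import numpy as np
import torch
from torchvision import datasets, transforms

CIFAR10_train = datasets.CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor())
CIFAR10_test = datasets.CIFAR10('./data', train=False, download=True, transform=transforms.ToTensor())

idx = torch.tensor(CIFAR10_train.targets) == 1
idx |= torch.tensor(CIFAR10_train.targets) == 2  # element-wise OR: label 1 or 2
dset_train = torch.utils.data.Subset(CIFAR10_train, np.where(idx.numpy())[0])

idx = torch.tensor(CIFAR10_test.targets) == 1
idx |= torch.tensor(CIFAR10_test.targets) == 2
dset_test = torch.utils.data.Subset(CIFAR10_test, np.where(idx.numpy())[0])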

and dataloaders are defined as usual:

dl_train = torch.utils.data.DataLoader(dset_train, batch_size=8, shuffle=True)
dl_test = torch.utils.data.DataLoader(dset_test, batch_size=8, shuffle=True)

This works too and seems to be simpler than SubsetRandomSampler.

@Sourena_Yadegari
Isn’t it also necessary to reindex the labels to do something useful with this dataset?

For example, nn.CrossEntropyLoss expects class indices in the range [0, C-1].
So in this case, labels 1 and 2 have to be remapped to 0 and 1; otherwise the criterion will throw an error.
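One way to do that reindexing without touching the underlying dataset is a thin wrapper that maps each original label to a contiguous index. A minimal sketch; RemappedSubset and label_map are hypothetical names, and a TensorDataset with labels 0 to 3 stands in for CIFAR10:

```python
import torch
from torch.utils.data import Dataset, Subset, TensorDataset

class RemappedSubset(Dataset):
    """Hypothetical wrapper: keep only labels in `classes`, renumbered 0..C-1."""

    def __init__(self, dataset, targets, classes):
        self.label_map = {c: i for i, c in enumerate(classes)}  # e.g. {1: 0, 2: 1}
        keep = [i for i, t in enumerate(targets) if int(t) in self.label_map]
        self.subset = Subset(dataset, keep)

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        x, y = self.subset[index]
        return x, self.label_map[int(y)]

images = torch.randn(6, 3, 32, 32)
targets = [0, 1, 2, 3, 2, 1]
ds = RemappedSubset(TensorDataset(images, torch.tensor(targets)), targets, classes=(1, 2))
print(len(ds), [ds[i][1] for i in range(len(ds))])  # 4 [0, 1, 1, 0]
```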

You will not face any issues unless your network/model has exactly 2 outputs (because you are training on 2 classes); in that case label 2 is out of range.

Also, if your network has 10 outputs and you are only training on 2 classes, there is no issue either, since labels 1 and 2 are still valid indices.

The condition is violated only when the labels exceed the number of outputs, for example when your dataset has “10 classes” and your network has “2 outputs”, but not the other way around.
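The rule can be checked directly with nn.CrossEntropyLoss, which requires every target to lie in [0, C-1], where C is the number of model outputs. A small sketch with random logits standing in for model outputs; the try_loss helper is hypothetical:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
targets = torch.tensor([1, 2, 2, 1])  # original labels, not remapped

def try_loss(num_outputs, y):
    """Hypothetical helper: return True if the loss can be computed."""
    logits = torch.randn(len(y), num_outputs)
    try:
        criterion(logits, y)
        return True
    except (IndexError, RuntimeError):  # raised for out-of-range targets
        return False

print(try_loss(10, targets))     # True: labels 1 and 2 fit into 10 outputs
print(try_loss(2, targets))      # False: label 2 is out of range for 2 outputs
print(try_loss(2, targets - 1))  # True: remapped to 0 and 1
```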