Hello, I’m studying MNIST and want to train a model on only the digit “1”, but I don’t know how to extract the “1” class from the full dataset. I only know this code:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
However, your model won’t learn anything as you just have one class.
Could you explain your use case a bit?
I would at least try to keep two classes in the dataset.
Or do you want to train your model to recognize the number 1 as the valid class and all remaining numbers as false?
If so, you could try this code:
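A minimal sketch of both readings of the question, using synthetic tensors in place of the MNIST `data` and `targets` attributes so that it runs without a download. The names `ones_only` and `one_vs_rest` are made up for illustration. It assumes the labels are stored in a tensor, as they are for MNIST.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 100 "images" with labels cycling through 0..9.
data = torch.rand(100, 1, 28, 28)
targets = torch.arange(100) % 10

# Reading A: keep only the samples of class 1.
mask = targets == 1                      # boolean tensor, one entry per sample
ones_only = TensorDataset(data[mask], targets[mask])

# Reading B: keep everything, but label "1" as 1 and all other digits as 0.
binary_targets = (targets == 1).long()
one_vs_rest = TensorDataset(data, binary_targets)

loader = DataLoader(one_vs_rest, batch_size=16, shuffle=True)
for images, labels in loader:
    pass  # the training step would go here
```

Reading B is usually the more useful one, since the model then sees both positive and negative examples.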
I tried using the same on CIFAR10, but it does not work.
idx = dataset.targets == 1
This comparison returns a single bool (True or False) instead of one value per sample.
Does this differ between datasets, or is there a standard way of doing it?
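The likely cause: CIFAR10 stores `targets` as a plain Python list, while MNIST stores them as a tensor. Comparing a list with an int compares the whole object, so the result is one bool. A small sketch with a made-up label list, converting to a tensor first so the comparison becomes elementwise:

```python
import torch

# Made-up labels, stored as a plain list the way CIFAR10 stores them.
list_targets = [3, 1, 0, 1, 7]

# list == int compares the whole list with the number, so this is one bool.
print(list_targets == 1)  # False

# as_tensor accepts lists and tensors alike, so the same line works for both.
mask = torch.as_tensor(list_targets) == 1
indices = mask.nonzero(as_tuple=True)[0].tolist()
print(indices)  # [1, 3]
```

Because `torch.as_tensor` passes an existing tensor through unchanged, the same two lines should work regardless of which container a dataset uses.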
However, I would consider these approaches as hacks and there might be some side effects I’m currently not aware of.
The clean approach would be to override the dataset class and manipulate the underlying data as you wish.
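One way this could look is a thin wrapper dataset instead of a subclass of a specific torchvision class. `ClassFilteredDataset` is a hypothetical name, and the base dataset here is synthetic so the sketch runs on its own. It only assumes the labels can be iterated and converted with `int()`.

```python
import torch
from torch.utils.data import Dataset, TensorDataset

class ClassFilteredDataset(Dataset):
    """Hypothetical wrapper that exposes only the samples of the wanted classes."""

    def __init__(self, base, targets, keep_classes):
        keep = set(keep_classes)
        self.base = base
        # Remember the positions of the wanted samples; the base data is untouched.
        self.indices = [i for i, t in enumerate(targets) if int(t) in keep]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.base[self.indices[i]]

# Synthetic stand-in for a CIFAR10-like dataset.
targets = torch.tensor([0, 1, 2, 1, 3, 1])
base = TensorDataset(torch.rand(6, 3, 32, 32), targets)

ones = ClassFilteredDataset(base, targets, keep_classes=[1])
print(len(ones))  # 3
```

This is close to what `torch.utils.data.Subset` already does once you have the index list, so `Subset(base, indices)` is a reasonable alternative if no extra logic is needed.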
import torch.utils.data as Data
import torchvision

def get_indices(dataset, class_name):
    # Collect the positions of all samples whose label equals class_name.
    indices = []
    for i in range(len(dataset.targets)):
        if dataset.targets[i] == class_name:
            indices.append(i)
    return indices

dataset = torchvision.datasets.CIFAR10(root='./data',
                                       transform=torchvision.transforms.ToTensor())
idx = get_indices(dataset, 1)
# The sampler draws only from the selected indices, so each batch contains class 1 only.
loader = Data.DataLoader(dataset, batch_size=64, sampler=Data.sampler.SubsetRandomSampler(idx))
for batch_idx, (data, target) in enumerate(loader):
    print(target)
I hope this works for all the datasets available in torchvision.datasets.
@Sourena_Yadegari
Isn’t it also necessary to reindex the data to do something useful with this dataset?
For example, nn.CrossEntropyLoss expects class indices in the range [0, C-1].
So if only the classes 1 and 2 are kept and the model has two outputs, the labels have to be reindexed to 0 and 1; otherwise the criterion will throw an error.
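A small sketch of that reindexing for the classes 1 and 2, with a hypothetical `class_map` dict and random logits standing in for a two-output model:

```python
import torch
import torch.nn as nn

# Hypothetical mapping from the original labels to contiguous indices 0..C-1.
class_map = {1: 0, 2: 1}

original = torch.tensor([1, 2, 2, 1])
remapped = torch.tensor([class_map[int(t)] for t in original])
print(remapped.tolist())  # [0, 1, 1, 0]

# Random logits standing in for the output of a model with C = 2 outputs.
logits = torch.randn(4, 2)
loss = nn.CrossEntropyLoss()(logits, remapped)
```

With torchvision datasets the same mapping can be applied per sample by passing something like `target_transform=lambda t: class_map[t]` to the dataset constructor, assuming only the mapped classes are ever sampled.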
I want to create a separate DataLoader for each of several subtasks, each consisting of distinguishing two digits: [0,1], [2,3], [4,5], [6,7], [8,9]. How can I do this?
You could use the same approach from this post and combine the conditions for the two wanted classes to create idx. The easiest approach would be to create 5 MNIST datasets and apply the conditions on each separately.
E.g. for the first DataLoader, this code should work:
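A sketch that builds all five loaders in one loop, again on synthetic tensors in place of MNIST so it runs without a download. It uses `Subset` over one shared dataset rather than five separate copies, which is a swapped-in variation on the suggestion above, and the names `pairs` and `loaders` are made up.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in for MNIST: 200 samples with labels cycling through 0..9.
data = torch.rand(200, 1, 28, 28)
targets = torch.arange(200) % 10
dataset = TensorDataset(data, targets)

pairs = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
loaders = []
for a, b in pairs:
    # Combine the two class conditions into one boolean mask.
    mask = (targets == a) | (targets == b)
    indices = mask.nonzero(as_tuple=True)[0].tolist()
    loaders.append(DataLoader(Subset(dataset, indices), batch_size=32, shuffle=True))

# Inside each subtask, subtracting the first digit maps the labels to 0 and 1.
for (a, b), loader in zip(pairs, loaders):
    for images, labels in loader:
        local_labels = labels - a
```

The `labels - a` shortcut only works because each pair is two consecutive digits; for arbitrary pairs a lookup dict would be needed.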