Cifar10 dataset divide

Hi everyone


# Human-readable names for the ten CIFAR-10 labels.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# CIFAR-10 training split, downloaded into ./data when requested.
# NOTE: `transform` is expected to be defined earlier in the script.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# A single loader that draws shuffled batches from all ten classes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

Here the trainloader contains all 10 classes. As an example, I have 5 clients and I want each client to receive only two classes (client 1 gets classes 1 and 2, client 2 gets classes 3 and 4, etc.).

Any ideas how to do that?

Or is there a way to select classes and data from trainset to make a trainloader with only two specific classes?

You can use the indices where the target meets your criteria to create a Subset for each client.

# Label names for the ten CIFAR-10 classes.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Full CIFAR-10 training split; the per-client subsets are carved out of it.
# NOTE: `transform` is expected to be defined earlier in the script.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Convert the label list to an array once instead of rebuilding it for
# every comparison.
targets = np.asarray(trainset.targets)

# Sample positions for each client: client k owns the samples whose label is
# one of its two class ids (0/1, 2/3, ..., 8/9). `np.isin` builds the same
# boolean mask as `(targets == a) | (targets == b)`.
client_1_idx = np.where(np.isin(targets, (0, 1)))[0]
client_2_idx = np.where(np.isin(targets, (2, 3)))[0]
client_3_idx = np.where(np.isin(targets, (4, 5)))[0]
client_4_idx = np.where(np.isin(targets, (6, 7)))[0]
client_5_idx = np.where(np.isin(targets, (8, 9)))[0]


# Wrap each client's index array in a Subset view over the shared trainset.
# The subsets are created in client order, one per index array.
client_1_ds, client_2_ds, client_3_ds, client_4_ds, client_5_ds = (
    torch.utils.data.Subset(trainset, client_idx)
    for client_idx in (
        client_1_idx,
        client_2_idx,
        client_3_idx,
        client_4_idx,
        client_5_idx,
    )
)

# One shuffled DataLoader per client, all with the same batch size and
# worker count; each loader only sees that client's subset.
client_1_dl, client_2_dl, client_3_dl, client_4_dl, client_5_dl = (
    torch.utils.data.DataLoader(
        client_ds,
        batch_size=128,
        shuffle=True,
        num_workers=2,
    )
    for client_ds in (
        client_1_ds,
        client_2_ds,
        client_3_ds,
        client_4_ds,
        client_5_ds,
    )
)
1 Like