Hi everyone
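import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.ToTensor()  # placeholder; my actual transform is defined elsewhere

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)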
This trainloader contains all 10 classes. Suppose I have 5 clients and want each client to receive only two classes (client 1 gets classes 1 and 2, client 2 gets classes 3 and 4, and so on).
Any ideas how to do that?
Alternatively, is it possible to select specific classes from trainset and build a trainloader that contains only those two classes?
You can collect the indices of the samples whose target is one of the classes a client should receive, and use them to create a torch.utils.data.Subset (and then a DataLoader) for each client.
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.ToTensor()  # placeholder; use the same transform as in your code

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# trainset.targets is a plain Python list; convert it once
targets = np.array(trainset.targets)

# Indices of the samples in each client's two classes (labels are 0-based)
client_1_idx = np.where((targets == 0) | (targets == 1))[0]
client_2_idx = np.where((targets == 2) | (targets == 3))[0]
client_3_idx = np.where((targets == 4) | (targets == 5))[0]
client_4_idx = np.where((targets == 6) | (targets == 7))[0]
client_5_idx = np.where((targets == 8) | (targets == 9))[0]

# Each Subset exposes only the selected samples of trainset
client_1_ds = torch.utils.data.Subset(trainset, client_1_idx)
client_2_ds = torch.utils.data.Subset(trainset, client_2_idx)
client_3_ds = torch.utils.data.Subset(trainset, client_3_idx)
client_4_ds = torch.utils.data.Subset(trainset, client_4_idx)
client_5_ds = torch.utils.data.Subset(trainset, client_5_idx)

# One DataLoader per client
client_1_dl = torch.utils.data.DataLoader(client_1_ds, batch_size=128, shuffle=True, num_workers=2)
client_2_dl = torch.utils.data.DataLoader(client_2_ds, batch_size=128, shuffle=True, num_workers=2)
client_3_dl = torch.utils.data.DataLoader(client_3_ds, batch_size=128, shuffle=True, num_workers=2)
client_4_dl = torch.utils.data.DataLoader(client_4_ds, batch_size=128, shuffle=True, num_workers=2)
client_5_dl = torch.utils.data.DataLoader(client_5_ds, batch_size=128, shuffle=True, num_workers=2)
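If the number of clients or classes per client changes, the index selection can be generated in a loop instead of written out by hand. Below is a minimal sketch, assuming each client takes a block of consecutive class ids (client k gets classes k*classes_per_client up to (k+1)*classes_per_client - 1). The helper name split_indices_by_class is hypothetical, and np.isin stands in for the chained (targets == a) | (targets == b) comparisons. It only uses NumPy, so it is demonstrated on a small toy label list rather than on CIFAR-10.

```python
import numpy as np


def split_indices_by_class(targets, num_clients, classes_per_client):
    """Hypothetical helper: return one index array per client.

    Assumes client k receives the consecutive class ids
    k*classes_per_client ... (k+1)*classes_per_client - 1.
    """
    targets = np.asarray(targets)
    client_indices = []
    for k in range(num_clients):
        wanted = np.arange(k * classes_per_client, (k + 1) * classes_per_client)
        # np.isin builds the same boolean mask as (targets == a) | (targets == b) | ...
        client_indices.append(np.where(np.isin(targets, wanted))[0])
    return client_indices


if __name__ == "__main__":
    toy_targets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 9]
    parts = split_indices_by_class(toy_targets, num_clients=5, classes_per_client=2)
    print([p.tolist() for p in parts])
    # prints [[0, 1, 10], [2, 3], [4, 5], [6, 7], [8, 9, 11]]
```

With the real dataset, each returned array can be passed to torch.utils.data.Subset(trainset, idx) and wrapped in a DataLoader exactly as in the code above, e.g. by calling split_indices_by_class(trainset.targets, num_clients=5, classes_per_client=2).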