You can use the indices of the samples whose targets meet your criteria to create a
`Subset` for each client.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Convert the label list to an array once, instead of once per comparison.
train_targets = np.array(trainset.targets)

# Label-skewed split: client k receives only the samples whose label is in its pair.
CLIENT_CLASS_PAIRS = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]

client_idx_list = [np.where(np.isin(train_targets, pair))[0] for pair in CLIENT_CLASS_PAIRS]
client_ds_list = [torch.utils.data.Subset(trainset, idx) for idx in client_idx_list]
client_dl_list = [
    torch.utils.data.DataLoader(ds, batch_size=128, shuffle=True, num_workers=2)
    for ds in client_ds_list
]

# Named per-client aliases kept for existing callers; these unpackings assume
# CLIENT_CLASS_PAIRS has exactly five entries.
(client_1_idx, client_2_idx, client_3_idx,
 client_4_idx, client_5_idx) = client_idx_list
(client_1_ds, client_2_ds, client_3_ds,
 client_4_ds, client_5_ds) = client_ds_list
(client_1_dl, client_2_dl, client_3_dl,
 client_4_dl, client_5_dl) = client_dl_list