Hello,
I produced a code to split the MNIST dataset into n subsets each with 2 mains classes, and then it saves the dataset and the data répartition. (ex: subset 1 should contain 882 images of class 0, and 1353 of class 9 and 0 of all the other classes.) It doesn’t work but I don’t undersand why…
This is the code:
mnist_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
def two_class_split_strat(clients):
# à faire pour chaque classe
global nb_clients
global num_classes
nb_clients = clients
num_classes = 10
dupl = int(nb_clients/num_classes)
r = []
to_save = f'general data repartition: {save_data_distribution(mnist_dataset)}'
cli_class = []
print("starting", flush=True)
for i in range(num_classes):
r.append([])
a = [ran.random() for _ in range(dupl*2)]
s = sum(a)
a = [b/s for b in a]
l = 0
for cli in range(nb_clients):
if(cli % num_classes == i or cli % num_classes == (i+1)%num_classes ):
r[-1].append(a[l])
l += 1
else:
r[-1].append(0)
print(r[-1], flush=True)
print(sum(r[-1]), flush=True)
for _ in range(nb_clients):
cli_class.append([])
for i in range(num_classes):
mnist_subsets = []
for j in range(0,len(mnist_dataset)):
if mnist_dataset[j][1] == i:
mnist_subsets.append(mnist_dataset[j])
mnist_subsets_class = random_split(mnist_subsets, r[i], generator=torch.Generator().manual_seed(42))
for m, subset in enumerate(mnist_subsets_class):
if len(subset) == 0:
continue
if(len(cli_class[m]) == 0):
cli_class[m] = subset
else:
cli_class[m] = torch.utils.data.ConcatDataset([subset, cli_class[m]])
for l in range(nb_clients):
to_save += f'client{l}: {save_data_distribution(cli_class[l])}\n'
filename = f'mnist_subset_{l}.pt'
torch.save(subset, path+filename)
to_save += f'r: {r}\n'
filename = 'data_distribution.txt'
f = open(path+filename, "w")
f.write(to_save)
f.close()
def save_data_distribution(train_set):
# Compute the data distribution of the training set
train_targets = [sample[1] for sample in train_set]
class_counts = [train_targets.count(i) for i in range(10)]
class_distribution = [count for count in class_counts]
return class_distribution
This is the data repartition that is saved in data_distribution.txt for n = 20 subsets:
(client0 represent the subset0)
general data repartition: [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]
client0: [882, 1, 1, 1, 1, 1, 1, 1, 1, 1353]
client1: [2678, 2038, 1, 1, 1, 1, 0, 0, 1, 1]
client2: [0, 1518, 955, 0, 0, 0, 0, 0, 0, 0]
client3: [0, 0, 1816, 295, 0, 0, 0, 0, 0, 0]
client4: [0, 0, 0, 1518, 1438, 0, 0, 0, 0, 0]
client5: [0, 0, 0, 0, 1734, 940, 0, 0, 0, 0]
client6: [0, 0, 0, 0, 0, 476, 2070, 0, 0, 0]
client7: [0, 0, 0, 0, 0, 0, 239, 2428, 0, 0]
client8: [0, 0, 0, 0, 0, 0, 0, 1354, 829, 0]
client9: [0, 0, 0, 0, 0, 0, 0, 0, 1386, 1540]
client10: [54, 0, 0, 0, 0, 0, 0, 0, 0, 1534]
client11: [2309, 1251, 0, 0, 0, 0, 0, 0, 0, 0]
client12: [0, 1934, 1742, 0, 0, 0, 0, 0, 0, 0]
client13: [0, 0, 1443, 2107, 0, 0, 0, 0, 0, 0]
client14: [0, 0, 0, 2209, 718, 0, 0, 0, 0, 0]
client15: [0, 0, 0, 0, 1950, 1715, 0, 0, 0, 0]
client16: [0, 0, 0, 0, 0, 2288, 2256, 0, 0, 0]
client17: [0, 0, 0, 0, 0, 0, 1352, 1403, 0, 0]
client18: [0, 0, 0, 0, 0, 0, 0, 1079, 2389, 0]
client19: [0, 0, 0, 0, 0, 0, 0, 0, 1245, 1521]
percentage of data per class by client
r: [[0.1487569991441706, 0.4520673047303473, 0, 0, 0, 0, 0, 0, 0, 0, 0.009199241350956257, 0.3899764547745259, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0.3021821171320493, 0.2252463655212973, 0, 0, 0, 0, 0, 0, 0, 0, 0.18564540590960202, 0.28692611143705143, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0.1603166759114562, 0.3049090970691048, 0, 0, 0, 0, 0, 0, 0, 0, 0.29246395399849495, 0.24231027302094407, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0.04814668026923616, 0.24773036934240952, 0, 0, 0, 0, 0, 0, 0, 0, 0.34377280062387866, 0.36035014976447566, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0.24630359780591282, 0.2968923062811091, 0, 0, 0, 0, 0, 0, 0, 0, 0.12297673540337589, 0.33382736050960227, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0.17353545008323226, 0.08794784425647079, 0, 0, 0, 0, 0, 0, 0, 0, 0.3164518273000846, 0.42206487836021245, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0.34983075186253837, 0.04040125657288295, 0, 0, 0, 0, 0, 0, 0, 0, 0.38130553434345765, 0.228462457221121, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0.3875893081540053, 0.2161609015241546, 0, 0, 0, 0, 0, 0, 0, 0, 0.22399185824655646, 0.17225793207528356, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0.14168683451834266, 0.2370124986910882, 0, 0, 0, 0, 0, 0, 0, 0, 0.40836477260539944, 0.21293589418516962],
[0.2274232031801365, 0, 0, 0, 0, 0, 0, 0, 0, 0.2589765922872685, 0.25788282591524647, 0, 0, 0, 0, 0, 0, 0, 0, 0.25571737861734856]]
Now the problem is that despite what is written in the data_distribution.txt file, all my subsets seem to contain the same data of length 1521… I couldn’t figure out what is wrong, any help is highly appreciated