Splitting MNIST dataset in n subset with 2 main classes per subset

Hello,

I produced a code to split the MNIST dataset into n subsets each with 2 mains classes, and then it saves the dataset and the data répartition. (ex: subset 1 should contain 882 images of class 0, and 1353 of class 9 and 0 of all the other classes.) It doesn’t work but I don’t undersand why…

This is the code:

mnist_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)

def two_class_split_strat(clients):
    # à faire pour chaque classe
    global nb_clients
    global num_classes
    nb_clients = clients
    num_classes = 10 
    dupl = int(nb_clients/num_classes)

    r = []
    to_save = f'general data repartition: {save_data_distribution(mnist_dataset)}'
    cli_class = []
    print("starting", flush=True)
    for i in range(num_classes):
        r.append([])
        a = [ran.random() for _ in range(dupl*2)]
        s = sum(a)
        a = [b/s for b in a]
        l = 0

        for cli in range(nb_clients):
            if(cli % num_classes == i or cli % num_classes == (i+1)%num_classes ):
                r[-1].append(a[l])
                l += 1
            else:
                r[-1].append(0)

        print(r[-1], flush=True)
        print(sum(r[-1]), flush=True)
    
    for _ in range(nb_clients):
        cli_class.append([])
    
    for i in range(num_classes):
        mnist_subsets = []
        for j in range(0,len(mnist_dataset)):
            if mnist_dataset[j][1] == i:
                mnist_subsets.append(mnist_dataset[j])


        mnist_subsets_class = random_split(mnist_subsets, r[i], generator=torch.Generator().manual_seed(42))
    
        for m, subset in enumerate(mnist_subsets_class):
            if len(subset) == 0:
                continue
            
            if(len(cli_class[m]) == 0):
                cli_class[m] = subset
            else:
                cli_class[m] = torch.utils.data.ConcatDataset([subset, cli_class[m]])

    for l in range(nb_clients):
        to_save += f'client{l}: {save_data_distribution(cli_class[l])}\n'
        filename = f'mnist_subset_{l}.pt'
        torch.save(subset, path+filename)

    to_save += f'r: {r}\n'
    filename = 'data_distribution.txt'
    f = open(path+filename, "w")
    f.write(to_save)
    f.close()

def save_data_distribution(train_set):
    # Compute the data distribution of the training set
    train_targets = [sample[1] for sample in train_set]
    class_counts = [train_targets.count(i) for i in range(10)]
    class_distribution = [count for count in class_counts]

    return class_distribution

This is the data repartition that is saved in data_distribution.txt for n = 20 subsets:
(client0 represent the subset0)

general data repartition: [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]

client0: [882, 1, 1, 1, 1, 1, 1, 1, 1, 1353]
client1: [2678, 2038, 1, 1, 1, 1, 0, 0, 1, 1]
client2: [0, 1518, 955, 0, 0, 0, 0, 0, 0, 0]
client3: [0, 0, 1816, 295, 0, 0, 0, 0, 0, 0]
client4: [0, 0, 0, 1518, 1438, 0, 0, 0, 0, 0]
client5: [0, 0, 0, 0, 1734, 940, 0, 0, 0, 0]
client6: [0, 0, 0, 0, 0, 476, 2070, 0, 0, 0]
client7: [0, 0, 0, 0, 0, 0, 239, 2428, 0, 0]
client8: [0, 0, 0, 0, 0, 0, 0, 1354, 829, 0]
client9: [0, 0, 0, 0, 0, 0, 0, 0, 1386, 1540]
client10: [54, 0, 0, 0, 0, 0, 0, 0, 0, 1534]
client11: [2309, 1251, 0, 0, 0, 0, 0, 0, 0, 0]
client12: [0, 1934, 1742, 0, 0, 0, 0, 0, 0, 0]
client13: [0, 0, 1443, 2107, 0, 0, 0, 0, 0, 0]
client14: [0, 0, 0, 2209, 718, 0, 0, 0, 0, 0]
client15: [0, 0, 0, 0, 1950, 1715, 0, 0, 0, 0]
client16: [0, 0, 0, 0, 0, 2288, 2256, 0, 0, 0]
client17: [0, 0, 0, 0, 0, 0, 1352, 1403, 0, 0]
client18: [0, 0, 0, 0, 0, 0, 0, 1079, 2389, 0]
client19: [0, 0, 0, 0, 0, 0, 0, 0, 1245, 1521]

percentage of data per class by client
r: [[0.1487569991441706, 0.4520673047303473, 0, 0, 0, 0, 0, 0, 0, 0, 0.009199241350956257, 0.3899764547745259, 0, 0, 0, 0, 0, 0, 0, 0], 
    [0, 0.3021821171320493, 0.2252463655212973, 0, 0, 0, 0, 0, 0, 0, 0, 0.18564540590960202, 0.28692611143705143, 0, 0, 0, 0, 0, 0, 0], 
    [0, 0, 0.1603166759114562, 0.3049090970691048, 0, 0, 0, 0, 0, 0, 0, 0, 0.29246395399849495, 0.24231027302094407, 0, 0, 0, 0, 0, 0], 
    [0, 0, 0, 0.04814668026923616, 0.24773036934240952, 0, 0, 0, 0, 0, 0, 0, 0, 0.34377280062387866, 0.36035014976447566, 0, 0, 0, 0, 0], 
    [0, 0, 0, 0, 0.24630359780591282, 0.2968923062811091, 0, 0, 0, 0, 0, 0, 0, 0, 0.12297673540337589, 0.33382736050960227, 0, 0, 0, 0], 
    [0, 0, 0, 0, 0, 0.17353545008323226, 0.08794784425647079, 0, 0, 0, 0, 0, 0, 0, 0, 0.3164518273000846, 0.42206487836021245, 0, 0, 0], 
    [0, 0, 0, 0, 0, 0, 0.34983075186253837, 0.04040125657288295, 0, 0, 0, 0, 0, 0, 0, 0, 0.38130553434345765, 0.228462457221121, 0, 0], 
    [0, 0, 0, 0, 0, 0, 0, 0.3875893081540053, 0.2161609015241546, 0, 0, 0, 0, 0, 0, 0, 0, 0.22399185824655646, 0.17225793207528356, 0], 
    [0, 0, 0, 0, 0, 0, 0, 0, 0.14168683451834266, 0.2370124986910882, 0, 0, 0, 0, 0, 0, 0, 0, 0.40836477260539944, 0.21293589418516962], 
    [0.2274232031801365, 0, 0, 0, 0, 0, 0, 0, 0, 0.2589765922872685, 0.25788282591524647, 0, 0, 0, 0, 0, 0, 0, 0, 0.25571737861734856]]

Now the problem is that despite what is written in the data_distribution.txt file, all my subsets seem to contain the same data of length 1521… I couldn’t figure out what is wrong, any help is highly appreciated

Hi, @Xiaa, could you please modify the posted code to be self contained and give the output that it produces? The current output and the code seem to differ.