Stagnation in the loss

I am building a framework for decentralized learning on a fully connected graph. The behavior of a single node with batch size 256 should be the same as that of 16 nodes with batch size 16 (when the full dataset is distributed randomly among the 16 nodes). My code works for a single client; however, once I increase the number of clients, the loss stagnates at a value of about 2.3 and only starts decreasing after a certain number of steps (how many steps depends on how much I increase the number of clients: 4, 8, 16, …). All multi-client experiments stagnate at the same loss value, and I do not understand why this is happening.
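
To make the equivalence I am assuming explicit, here is a standalone toy check (a made-up linear model and random data, not part of my framework): averaging the gradients of 16 mini-batches of size 16 should match the gradient of one batch of 256, because CrossEntropyLoss averages over the batch.

import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 3)
criterion = torch.nn.CrossEntropyLoss()

x = torch.randn(256, 10)
y = torch.randint(0, 3, (256,))

# single node, batch size 256
model.zero_grad()
criterion(model(x), y).backward()
g_full = [p.grad.clone() for p in model.parameters()]

# 16 "nodes", batch size 16 each, gradients averaged
g_avg = [torch.zeros_like(p) for p in model.parameters()]
for xi, yi in zip(x.chunk(16), y.chunk(16)):
    model.zero_grad()
    criterion(model(xi), yi).backward()
    for g, p in zip(g_avg, model.parameters()):
        g += p.grad / 16

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_full, g_avg)))  # expect True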

Client class

class Client(object):

    def __init__(self, client_id, local_steps, task, learning_rate, batch_size, device, lambd, model, train_loaders,
                 test_loaders):
        self.client_id = client_id
        self.device = device
        self.local_steps = local_steps
        self.task = task
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        # added for regularization
        self.lambd = lambd
        self.model = model
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.0)
        self.criterion = torch.nn.CrossEntropyLoss()
        # train
        self.train_set = train_loaders
        self.train_loader = iter(self.train_set)
        # test
        self.test_set = test_loaders
        self.test_loader = iter(self.test_set)

    def train(self, model):
        self.model = model

        ac = []
        los = []
        run_loss = 0
        for completed_steps in range(self.local_steps):
            correct = 0
            total = 0
            try:
                inputs, labels = next(self.train_loader)
            except StopIteration:
                print("except train")
                self.train_loader = iter(self.train_set)
                inputs, labels = next(self.train_loader)
            # zero the parameter gradients
            self.optimizer.zero_grad()
            # forward + backward + optimize
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            run_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            acc = 100. * correct / total
            ac.append(acc)
            los.append(run_loss)

        acc=sum(ac)/len(ac)
        run_loss=sum(los)/len(los)
        results = {'clientId': self.client_id, 'update_weight': self.model.state_dict(), 'train_acc': acc}
        results['train_loss'] = run_loss
        return results
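
For completeness, this is roughly how a single client can be driven in isolation (a hypothetical toy model and random data for illustration only; in my setup all objects are created through Hydra's instantiate in the class below):

import torch
from torch.utils.data import DataLoader, TensorDataset

# toy stand-ins for the real CIFAR loaders and model (illustration only)
toy_data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
toy_loader = DataLoader(toy_data, batch_size=16, shuffle=True)
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))

client = Client(client_id=0, local_steps=1, task="classification",
                learning_rate=0.01, batch_size=16, device=torch.device("cpu"),
                lambd=0.0, model=toy_model, train_loaders=toy_loader,
                test_loaders=toy_loader)

results = client.train(toy_model)
print(results["clientId"], results["train_acc"], results["train_loss"])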

The class responsible for training clients:

class trainClients(object):

    def __init__(self, cfg, hydra_cfg):

        # self.device = 'cuda' if args.gpu else 'cpu'
        self.cfg = cfg
        self.device = torch.device('cuda:0') if self.cfg.device else torch.device('cpu')

        self.training_sets = self.test_dataset = None
        self.models = []
        self.optimizers = []
        self.epoch = 0
        self.client_id = 0
        self.hydra_cfg = hydra_cfg
        self.aggregation = False

    def train_clients(self):

        print("################ CIFAR TRAIN FUNCTION")
        #WANDB log
        if self.cfg.use_wandb:
            if str(self.hydra_cfg["split"]) == "non_iid_split":
                missing = self.cfg.split.missing
                a = str(self.cfg.split.a)
            else:
                missing = "_"
                a = "_"

            top_name = self.cfg.split.top_name
            run_name = "DL" + "_" + str(self.cfg.decentralized) + "_" + "Clt" + "_" + str(
                self.cfg.num_clients) + "_" + "Rds" + "_" + str(self.cfg.rounds) + "_" + "Stp" + "_" + str(
                self.cfg.client_training.local_steps) + "_" + "Top" + "_" + str(
                top_name) + "_a_" + a + "_" + "Opt" + "_" + str(self.hydra_cfg["optim"]) + "_" + "Data" + "_" + str(
                self.hydra_cfg["datamodule"]) + "_Splt_" + str(self.hydra_cfg["split"]) + "_miss_" + str(missing)
            time_stamp = datetime.datetime.fromtimestamp(time.time()).strftime('%m%d_%H%M%S')
            print(wandb.config)
            wandb_run = instantiate(self.cfg.logger, id=run_name + '_' + "time_stamp= " + time_stamp)
            print(self.cfg.logger)
            self.cfg.client_training.client_id = self.client_id
            wandb.config = OmegaConf.to_container(
                self.cfg, resolve=True, throw_on_missing=True
            )

        datamodule = instantiate(self.cfg.datamodule)
        # train and test data:
        log.info("load_and_split")
        train_sets, test_sets = datamodule.load_and_split(self.cfg)

        # Create number of clients models and Create clients
        log.info("Preparing the clients and the models and datasets..")
        clients = []
        self.cfg.model.num_classes = datamodule.num_classes
        for i in range(self.cfg.num_clients):
            print("i",i)
            # Model
            model = instantiate(self.cfg.model)
            model.apply(initialize_weights)
            self.models.append(model)

            # Dataset
            train_loader = datamodule.data_loaders(train_sets)
            test_loader = datamodule.test_loader(test_sets)

            # client object
            self.cfg.client_training.client_id = self.client_id
            client = instantiate(self.cfg.client_training, device=self.device, model=model, train_loaders=train_loader,
                                 test_loaders=test_loader)
            clients.append(client)

            self.client_id += 1
            datamodule.next_client()

        self.client_id = 0
        datamodule.set_client()



        #Connections between clients:
        adjacency_matrix = graph_generation(self.cfg.num_clients)
        print(adjacency_matrix)

        log.info("############# Start training ###############")
        for t in range(self.cfg.rounds):
            log.info("####### This is ROUND number {}  ######".format(t))
            datamodule.set_client()
            self.client_id = 0
            batch = 0
            batches = self.cfg.split.data_points / self.cfg.datamodule.batch_size
            log.info("number of batches {}".format(int(batches)))
            test_acc, train_acc, test_loss, train_loss = [], [], [], []
            while batch < int(batches):
                datamodule.set_client()
                self.client_id = 0
                local_weights = OrderedDict()
                # for client in clients:
                for i in range(len(clients)):
        
                    results = clients[i].train(self.models[clients[i].client_id])
                    train_acc.append(results["train_acc"])
                    train_loss.append(results["train_loss"])

                    # test
                    test_results = clients[i].test()
                    test_acc.append(test_results["test_acc"])
                    test_loss.append(test_results["test_loss"])

                    if i==0:
                        wandb.log({'client/train_loss': results["train_loss"],
                                   'client/test_loss': test_results["test_loss"],
                                   'client/train_accuracy': results["train_acc"],
                                   'client/test_accuracy': test_results["test_acc"]
                                   },
                                  step=batch + (t * int(batches)))

                if batch % 1 == 0 and batch != 0:  # aggregate after every batch (skipping batch 0)
                    log.info("aggregation step ......")

                    for c, model in enumerate(self.models):
                        local_weights[c] = model.state_dict()

                    self.models = decentralized_average(local_weights, self.models)

                    train_acc_batch = sum(train_acc) / len(train_acc)
                    train_loss_batch = sum(train_loss) / len(train_loss)
                    test_acc_batch = sum(test_acc) / len(test_acc)
                    test_loss_batch = sum(test_loss) / len(test_loss)
                    wandb.log({'Batch/train_loss': train_loss_batch,
                               'Batch/test_loss': test_loss_batch,
                               'Batch/train_accuracy': train_acc_batch,
                               'Batch/test_accuracy': test_acc_batch
                               },
                              step=batch + (t * int(batches))) #step=
                    test_acc, train_acc, test_loss, train_loss = [], [], [], []

                batch += 1
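
A sanity check that can be added right after the aggregation step: every client model should then hold exactly the same parameters (a minimal sketch, assuming self.models is a list of models with identical architecture; the helper name is made up):

import torch

def models_are_identical(models):
    # compare every model's state dict against the first one, tensor by tensor
    reference = models[0].state_dict()
    for model in models[1:]:
        state = model.state_dict()
        for key in reference:
            if not torch.equal(reference[key], state[key]):
                return False
    return True

# e.g. right after: self.models = decentralized_average(local_weights, self.models)
# assert models_are_identical(self.models)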

Aggregation function:

def decentralized_average(w, models):
    """
    Returns the average of the weights.
    This function is correct only for a fully connected graph.
    """

    weights = list(w.values())
    w_avg = copy.deepcopy(weights[0])

    for key in w_avg.keys():  # keys are layer names
        for i in range(1, len(weights)):  # loop over the remaining clients
            w_avg[key] += weights[i][key]  # sum this layer's weights over all clients
        w_avg[key] = torch.div(w_avg[key], len(weights))  # divide to get the average

    for i in range(len(models)):
        models[i].load_state_dict(w_avg) # all n models will have the same weights which is the average

    return models
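
On a toy example the averaging behaves as expected (two made-up single-layer models, for illustration only):

import copy
import torch

m1, m2 = torch.nn.Linear(2, 1), torch.nn.Linear(2, 1)
with torch.no_grad():
    m1.weight.fill_(0.0); m1.bias.fill_(0.0)
    m2.weight.fill_(2.0); m2.bias.fill_(2.0)

w = {0: m1.state_dict(), 1: m2.state_dict()}
m1, m2 = decentralized_average(w, [m1, m2])
print(m1.state_dict()["weight"], m2.state_dict()["weight"])  # both are 1.0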

Topology

def graph_generation(num_clients):
    seed = 31
    np.random.seed(seed)
    G = networkx.binomial_graph(num_clients, p=1, seed=42)  # with p=1 this is a fully connected (complete) graph

    adjacency_matrix = networkx.adjacency_matrix(G).todense() + np.identity(num_clients)
    graph_drawing(G, name="binomial_graph")
    return adjacency_matrix / adjacency_matrix.sum(1)
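
With p=1 the generated graph is complete, so adding the identity gives an all-ones matrix and every entry of the returned (row-normalized) matrix is 1/num_clients. For example, with 4 clients (reproducing the computation inside graph_generation, without the graph_drawing call):

import numpy as np
import networkx

G = networkx.binomial_graph(4, p=1, seed=42)
A = networkx.adjacency_matrix(G).todense() + np.identity(4)
print(A / A.sum(1))  # every entry is 0.25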

[Loss curves: the blue curve is 16 clients, the green curve is a single client.]
