How to aggregate gradients in PyTorch?

Hello there,

My case is a bit complex. I have 8 networks, one of which acts as the global server for the other 7.

My aim is to train each of the 7 worker networks for one epoch, take the weighted average of their gradients, copy the aggregated gradients into the 8th (server) network, and make one optimizer step with those aggregated gradients, i.e. Federated Learning. However, it does not work the way it is supposed to. My main questions are:

1- Can I store each worker’s gradients using:

grads = []
for p in net.parameters():
    grads.append(p.grad)
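
Or do I need to detach and clone them first, so the stored tensors are independent copies rather than references that later backward passes may change in place? Something along these lines (a rough sketch; the helper name is mine):

import torch

def snapshot_grads(net):
    # Store an independent copy of each parameter's current gradient.
    grads = []
    for p in net.parameters():
        if p.grad is not None:
            grads.append(p.grad.detach().clone())
        else:
            # Parameters that have not received a gradient yet.
            grads.append(torch.zeros_like(p))
    return grads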

2- Also, can I aggregate the gradients by setting sbs_grad = copy.deepcopy(temp_grad) for the first worker and then adding the remaining workers’ gradients on top of it using

for j in range(len(sbs_grad)):
    sbs_grad[j] += temp_grad[j]

then dividing it by the number of workers?
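
In other words, is a per-parameter average across workers like the following the right idea (a rough sketch; worker_grads here would be the list of per-worker gradient lists):

import torch

def average_grads(worker_grads):
    # worker_grads: one list of gradient tensors per worker, all in the same parameter order.
    return [torch.stack(per_param).mean(dim=0) for per_param in zip(*worker_grads)]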

3- Finally, can I transfer the aggregated gradients to the server network using:

loss.backward(retain_graph=True)
index = 0
for p in net.parameters():
    p.grad = cluster_grads[index]
    index = index + 1
optimizer.step()
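
Or, since those freshly computed gradients get overwritten anyway, would it be enough to skip the backward pass on the server and just load the aggregated gradients before stepping? Roughly like this (a sketch, assuming cluster_grads is ordered like net.parameters() and lives on the same device):

def apply_aggregated_grads(net, optimizer, cluster_grads):
    # Overwrite each parameter's gradient with the aggregated one, then step once.
    optimizer.zero_grad()
    for p, g in zip(net.parameters(), cluster_grads):
        p.grad = g.clone()
    optimizer.step()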

Any help would be much appreciated.

Here is my attempt, for those who are interested:

On workers:

def HFL_worker_train(worker, device, optimizer, epoch, criterion, i, h, c):
    print('\n FL Epoch: %d, Cluster %d, Local Iteration %d, User %d, ' % (epoch+1, c+1, h+1, i+1))
    net = worker.worker_net
    trainloader = worker.train_DataLoader
    grads = []
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        # Snapshot this batch's gradients as independent copies; appending
        # p.grad directly would only store references that later batches change.
        grads = [p.grad.detach().clone() for p in net.parameters()]
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    tr_acc = 100.*correct/total
    # Return the last batch's gradients so the caller can aggregate them.
    return grads

Aggregating gradients:

def hfl_localtraining(H, hfl_network, device, epoch, criterion, learning_rate, hfl_worker_optimizer_list, hfl_sbs_optimizer_list, no_of_clusters, no_workers_in_cluster, c):
    for i in range(no_workers_in_cluster):
        for h in range(H):
            temp_grad = HFL_worker_train(hfl_network.Clusters[c].Workers[i], device, hfl_worker_optimizer_list[c][i], epoch, criterion, i, h, c)
            if H - h == 1:  # last local iteration: collect this worker's gradients
                if i == 0:
                    sbs_grad = [g.clone() for g in temp_grad]
                else:
                    for j in range(len(sbs_grad)):
                        sbs_grad[j] += temp_grad[j]
    # Divide once at the end so every worker is weighted equally.
    sbs_grad = [g / no_workers_in_cluster for g in sbs_grad]
    return sbs_grad
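
For context, this is roughly how I wire the two pieces together (sbs_worker and sbs_optimizer below are placeholders for my actual SBS object and its optimizer):

# Placeholder names: sbs_worker / sbs_optimizer stand in for my actual SBS objects.
cluster_grads = hfl_localtraining(H, hfl_network, device, epoch, criterion, learning_rate,
                                  hfl_worker_optimizer_list, hfl_sbs_optimizer_list,
                                  no_of_clusters, no_workers_in_cluster, c)
HFL_SBS_train(sbs_worker, device, sbs_optimizer, epoch, criterion, c, cluster_grads)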

After training all workers:

def HFL_SBS_train(worker, device, optimizer, epoch, criterion, c, cluster_grads):
    print('\n FL Epoch: %d, Cluster %d SBS' % (epoch+1, c+1))
    net = worker.worker_net
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    trainloader = worker.train_DataLoader
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()  # the local gradients are overwritten by cluster_grads below
        # Note: the same cluster_grads are applied at every batch of this loop.
        for p, g in zip(net.parameters(), cluster_grads):
            p.grad = g.clone().to(p.device)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    tr_acc = 100.*correct/total
    worker.train_loss.append(train_loss/len(trainloader))
    worker.train_acc.append(100.*correct/total)

    return worker.worker_net.state_dict()