Hello, I am working on federated learning and following a tutorial. Everything is understandable; however, the author uses a weighted average:
```python
def server_aggregate(global_model, client_models, client_lens):
    """
    This function implements the aggregation method 'wmean'.
    wmean takes the weighted mean of the weights of the models.
    """
    total = sum(client_lens)
    n = len(client_models)
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack(
            [client_models[i].state_dict()[k].float() * (n * client_lens[i] / total)
             for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())
```
- `total`: total number of images (data samples) used across all clients
- `n`: number of client models obtained
- `client_lens[i]`: number of data samples used to train model `i`
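To make the setup concrete, here is a minimal, self-contained sketch of how the function above might be called. The function body is repeated so the snippet runs standalone; the tiny `nn.Linear` model, the random perturbation standing in for local training, and the client sample counts are all made up for illustration:

```python
import copy
import torch
import torch.nn as nn

# The tutorial's aggregation function, repeated here so the snippet runs standalone.
def server_aggregate(global_model, client_models, client_lens):
    total = sum(client_lens)
    n = len(client_models)
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack(
            [client_models[i].state_dict()[k].float() * (n * client_lens[i] / total)
             for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

# Hypothetical tiny model, just for illustration.
torch.manual_seed(0)
global_model = nn.Linear(4, 2)
client_models = [copy.deepcopy(global_model) for _ in range(3)]

# Pretend each client trained locally by nudging its weights.
for m in client_models:
    with torch.no_grad():
        for p in m.parameters():
            p.add_(torch.randn_like(p) * 0.1)

client_lens = [100, 50, 25]  # number of samples each client trained on
server_aggregate(global_model, client_models, client_lens)
# Afterwards, every client model holds the same aggregated weights as the server.
```

After the call, all client models are synchronized to the new global weights, which is what the final loop in the function does.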
As you can see, there is a for loop where the weighted average of the weights is computed:

```python
for k in global_dict.keys():
    global_dict[k] = torch.stack(
        [client_models[i].state_dict()[k].float() * (n * client_lens[i] / total)
         for i in range(len(client_models))], 0).mean(0)
```
My question is: why is the author multiplying by `n` when taking the mean?

```python
n * client_lens[i] / total
```
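To pin down what the expression actually computes, here is a minimal numeric sketch of that line in isolation (the tensor values are made up), comparing it against a plain weighted sum with weights `client_lens[i] / total`. Note that `.mean(0)` divides the stacked sum by `n`, so the extra factor of `n` cancels:

```python
import torch

n = 3
client_lens = [100, 50, 25]
total = sum(client_lens)

# Made-up per-client tensors standing in for one parameter entry k.
weights = [torch.full((2, 2), v) for v in (1.0, 2.0, 4.0)]

# The tutorial's expression: scale by n * client_lens[i] / total, then mean over dim 0.
with_n = torch.stack(
    [weights[i] * (n * client_lens[i] / total) for i in range(n)], 0).mean(0)

# A plain weighted sum with weights client_lens[i] / total, no factor of n.
weighted_sum = sum(weights[i] * (client_lens[i] / total) for i in range(n))

print(torch.allclose(with_n, weighted_sum))  # prints: True
```

In other words, the numeric result is identical to summing the client weights scaled by `client_lens[i] / total`; the `n` only compensates for the `1/n` that `.mean(0)` introduces.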