Implementation 1
FedAvg: averaging and loading the weights in the same loop

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack(
            [client_models[i].state_dict()[k].float()
             for i in range(len(client_models))], 0
        ).mean(0)
        # load inside the loop, once per key
        global_model.load_state_dict(global_dict)
        for model in client_models:
            model.load_state_dict(global_model.state_dict())
Implementation 2
FedAvg: averaging and loading the weights in different loops

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack(
            [client_models[i].state_dict()[k].float()
             for i in range(len(client_models))], 0
        ).mean(0)
    # load once, after the averaging loop has finished
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())
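For reference, the averaging-then-broadcasting step can be run in isolation with toy models. This is a sketch, not the full training code: the two-client setup, the `nn.Linear(2, 1)` architecture, and the saved weight copies are illustrative assumptions added here to make the snippet self-contained.

```python
import torch
import torch.nn as nn

# Toy setup: two clients and one global model with identical architectures.
torch.manual_seed(0)
global_model = nn.Linear(2, 1)
client_models = [nn.Linear(2, 1) for _ in range(2)]

# Keep copies of the pre-averaging client weights (for verification only).
orig_weights = [m.weight.detach().clone() for m in client_models]

# FedAvg step: element-wise mean of each parameter tensor across clients.
global_dict = global_model.state_dict()
for k in global_dict.keys():
    global_dict[k] = torch.stack(
        [client_models[i].state_dict()[k].float()
         for i in range(len(client_models))], 0
    ).mean(0)
global_model.load_state_dict(global_dict)  # load once, after the loop

# Broadcast the averaged weights back to every client.
for model in client_models:
    model.load_state_dict(global_model.state_dict())
```

After this step every client holds exactly the averaged global weights, so the next local training round starts from a common initialization.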
Implementation 2 gives me a test accuracy for the global model similar to the client models, but implementation 1 does not. I don't know the reason behind this.
Also, I am using a pre-trained model with additional layers, and I want to keep the pre-trained layers frozen during training on each client.
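One common way to do this is to set `requires_grad = False` on the pre-trained parameters before local training. The sketch below uses a hypothetical `nn.Sequential` stand-in (the layer sizes and the "last layer is the new head" split are assumptions, not taken from the original model):

```python
import torch
import torch.nn as nn

# Hypothetical model: pre-trained layers followed by one newly added layer.
model = nn.Sequential(
    nn.Linear(4, 8),   # stands in for the pre-trained layers
    nn.ReLU(),
    nn.Linear(8, 2),   # stands in for the newly added layer
)

# Freeze everything, then re-enable gradients for the new layer only.
for param in model.parameters():
    param.requires_grad = False
for param in model[-1].parameters():
    param.requires_grad = True

# Give the optimizer only the trainable parameters.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.01
)
```

Since the frozen layers start identical on every client and never change locally, averaging them in the FedAvg step should be a no-op, so freezing does not conflict with the aggregation logic.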