I am implementing a federated learning algorithm in which each data point in a batch is held by a different client. To train a model (say, an MLP), each client computes the gradient on its own data point and updates its local model. The server then aggregates (averages) the local models of all clients to update the global model state.
This is different from the centralized setup, where the gradient is computed on the entire batch in one shot and the model is then updated. To illustrate:
Centralized:

```python
...
loss = nn.CrossEntropyLoss()(logits, labels)  # nn.CrossEntropyLoss expects (input, target)
loss.backward()
optimizer.step()  # SGD
...
```
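Federated:

```python
# Local update at client i
loss[i] = nn.CrossEntropyLoss()(logits[i], labels[i])
loss[i].backward()
optimizer[i].step()

# Global model on the server (pseudocode: parameter-wise average of the client models)
model = mean_i(model[i])
```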
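A minimal runnable sketch of the two schemes above, so the gap can be measured step by step. It assumes a toy MLP, random data, a fixed learning rate and one local SGD step per client per round; `make_model`, `centralized_step`, `federated_round`, `N`, `D`, `C` and `lr` are hypothetical names for illustration only, not the code from my actual setup.

```python
import copy
import torch
import torch.nn as nn

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)

# Hypothetical toy data: N points, each held by a different client.
N, D, C = 8, 5, 3
X = torch.randn(N, D)
y = torch.randint(0, C, (N,))
lr = 0.1

def make_model():
    return nn.Sequential(nn.Linear(D, 16), nn.ReLU(), nn.Linear(16, C))

def centralized_step(model):
    """One full-batch SGD step on the mean cross-entropy loss."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    opt.zero_grad()
    loss = nn.CrossEntropyLoss()(model(X), y)  # (logits, labels), mean over the batch
    loss.backward()
    opt.step()

def federated_round(model):
    """Each client takes one SGD step on its own point; the server averages the weights."""
    local_states = []
    for i in range(N):
        local = copy.deepcopy(model)  # every client starts from the current global model
        opt = torch.optim.SGD(local.parameters(), lr=lr)
        opt.zero_grad()
        loss = nn.CrossEntropyLoss()(local(X[i:i + 1]), y[i:i + 1])  # batch of one
        loss.backward()
        opt.step()
        local_states.append(local.state_dict())
    # Parameter-wise mean of the client models.
    avg = {k: torch.stack([s[k] for s in local_states]).mean(dim=0) for k in local_states[0]}
    model.load_state_dict(avg)

def max_param_diff(a, b):
    return max((p - q).abs().max().item() for p, q in zip(a.parameters(), b.parameters()))

central = make_model()
federated = copy.deepcopy(central)  # identical initial weights

diffs = []
for step in range(20):
    centralized_step(central)
    federated_round(federated)
    diffs.append(max_param_diff(central, federated))
    print(f"step {step:2d}  max |w_central - w_federated| = {diffs[-1]:.3e}")
```

Printing the per-step difference shows whether the gap stays at rounding level or keeps growing over the rounds.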
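I am observing that the model computed in the federated setup consistently diverges from the one computed in the centralized setup, even though in theory this should not happen, since I am using plain SGD with no randomness. Specifically, the one-shot gradient on the batch should be the same as the average of the gradients of all individual points. And because every client starts from the same weights w and uses the same learning rate η, averaging the updated local models gives mean_i(w − η g_i) = w − η mean_i(g_i), which is exactly the centralized update.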
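The gradient identity can be checked directly. This is a sketch, assuming PyTorch 2.x (for `torch.func`) and a hypothetical toy MLP with random data; it compares the gradient of the batch-mean loss with the mean of per-sample gradients computed via `vmap`. Note that the identity relies on the loss being averaged over the batch (`reduction="mean"`, the default); with `reduction="sum"` the batch gradient is N times the average.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.set_default_dtype(torch.float64)
torch.manual_seed(0)

N, D, C = 8, 5, 3  # hypothetical sizes
X = torch.randn(N, D)
y = torch.randint(0, C, (N,))
model = nn.Sequential(nn.Linear(D, 16), nn.ReLU(), nn.Linear(16, C))
params = {k: v.detach() for k, v in model.named_parameters()}

def loss_fn(p, x, t):
    # x: (B, D), t: (B,); cross-entropy averaged over the batch.
    return nn.functional.cross_entropy(functional_call(model, p, (x,)), t)

# One-shot gradient of the mean loss over the whole batch.
batch_grad = grad(loss_fn)(params, X, y)

# Per-sample gradients: vmap over the points, each treated as a batch of one.
def single_loss(p, x, t):
    return loss_fn(p, x.unsqueeze(0), t.unsqueeze(0))

per_sample = vmap(grad(single_loss), in_dims=(None, 0, 0))(params, X, y)
mean_grad = {k: g.mean(dim=0) for k, g in per_sample.items()}

for k in params:
    diff = (batch_grad[k] - mean_grad[k]).abs().max().item()
    print(f"{k}: max |batch - mean of per-sample| = {diff:.3e}")
```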
The initial model weights, gradient-computation rules, and model update rules are all the same. I use full-batch gradient descent, so there is no mini-batch sampling. I have tried to fix all possible sources of randomness and turned off every optimization I know of:
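```python
import torch

torch.set_default_dtype(torch.float64)
torch.backends.cuda.matmul.allow_tf32 = False  # TF32 in matmuls (PyTorch has no allow_tf64 flag)
torch.backends.cudnn.allow_tf32 = False        # TF32 in cuDNN convolutions
torch.backends.cudnn.deterministic = True      # must be spelled exactly; a misspelled name raises no error
```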
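For reference, PyTorch's reproducibility notes list a few more switches beyond the ones above. A minimal sketch, assuming a recent PyTorch; the `CUBLAS_WORKSPACE_CONFIG` value is one of the two settings the notes give, and it only matters on CUDA 10.2 or later.

```python
import os

# Must be set before the first cuBLAS call, hence before importing torch here.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch

torch.manual_seed(0)                      # seeds the RNG on all devices
torch.backends.cudnn.benchmark = False    # no autotuning, which may pick different kernels per run
torch.use_deterministic_algorithms(True)  # error out on ops lacking a deterministic implementation

print(torch.are_deterministic_algorithms_enabled())
```

These settings target run-to-run nondeterminism; they do not change the order in which a batched kernel accumulates its sums.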