I am conducting some Federated Learning experiments using a simple fully-connected PyTorch model for classification, with CrossEntropyLoss()
as the loss function. Instead of averaging the model parameters as in FedAvg, I'm summing them at each training round. After a certain number of rounds, the loss becomes NaN and the accuracy drops sharply from 0.9 to 0.2, with no recovery. However, when I restrict the parameter values to the range [-1, 1] using torch.clamp(params, min=-1, max=1)
at each iteration, the model continues to improve as expected.
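
Schematically, the aggregation step looks like the sketch below (a simplified stand-in; the function and variable names are illustrative, not my actual code):

```python
import torch

@torch.no_grad()
def aggregate(global_model, client_states, average=False, clamp=False):
    """Combine the clients' state_dicts into the global model.

    average=False -> plain summation (what I'm doing now)
    average=True  -> FedAvg-style mean
    clamp=True    -> restrict each parameter to [-1, 1] after combining
    """
    new_state = {}
    for name in client_states[0]:
        # Stack the corresponding tensor from every client's state_dict
        stacked = torch.stack([state[name] for state in client_states])
        combined = stacked.mean(dim=0) if average else stacked.sum(dim=0)
        if clamp:
            combined = torch.clamp(combined, min=-1, max=1)
        new_state[name] = combined
    global_model.load_state_dict(new_state)
```

With average=False and clamp=False the NaN appears after some rounds; with clamp=True the training stays stable.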
Is this issue related to a limitation in PyTorch, or is it a theoretical problem (such as exploding gradients) that I might be overlooking?