Uncontrolled growth of parameter values leads to NaN loss

I am conducting some Federated Learning experiments using a simple fully-connected PyTorch model for classification, with CrossEntropyLoss() as the loss function. Instead of averaging the model parameters as in FedAvg, I’m summing them at each training round. After a certain number of rounds, the loss value starts to become NaN, and the accuracy drops sharply from 0.9 to 0.2, with no recovery. However, when I restrict the parameter values to the range [−1,1] using torch.clamp(params, min=-1, max=1) at each iteration, the model continues to improve as expected.
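
For reference, the aggregation step looks roughly like this (a simplified sketch; the actual model and variable names differ):

```python
import torch
import torch.nn as nn

def aggregate_sum(global_model: nn.Module, client_models: list[nn.Module]) -> None:
    """Sum client parameters into the global model (FedAvg would .mean() here instead)."""
    with torch.no_grad():
        client_params = [dict(m.named_parameters()) for m in client_models]
        for name, global_param in global_model.named_parameters():
            summed = torch.stack([p[name] for p in client_params]).sum(dim=0)
            global_param.copy_(summed)
```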

Is this issue related to a limitation in PyTorch, or is it a theoretical problem (such as exploding gradients) that I might be overlooking?

I suspect it’s mostly exploding gradients, as you said.

As the gradients grow, they can start producing NaNs. This can also occur by exceeding the data type’s limits (although I’m not sure about this).
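
For example, once values exceed float32’s maximum (roughly 3.4e38) they become inf, and inf arithmetic quickly turns into NaN:

```python
import torch

# float32 tops out around 3.4e38; going past that gives inf,
# and subtracting two infs (as can happen inside softmax/normalization) gives NaN.
x = torch.tensor(torch.finfo(torch.float32).max)
print(x * 2)          # tensor(inf)
print(x * 2 - x * 2)  # tensor(nan)
```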

Clamping essentially prevents this by keeping the values within a bounded range.
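
Something along these lines after each aggregation or update step keeps the parameters bounded (a minimal sketch, assuming a standard nn.Module):

```python
import torch
import torch.nn as nn

def clamp_parameters(model: nn.Module, low: float = -1.0, high: float = 1.0) -> None:
    # Clamp every parameter in place so values stay within [low, high].
    with torch.no_grad():
        for param in model.parameters():
            param.clamp_(low, high)
```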


Thank you for confirming my suspicions.