Thanks for the code snippet!
I debugged it a bit and think you are seeing non-determinism caused by atomic operations.
From the reproducibility docs:
There are some PyTorch functions that use CUDA functions that can be a source of non-determinism. One class of such CUDA functions are atomic operations, in particular atomicAdd, where the order of parallel additions to the same value is undetermined and, for floating-point variables, a source of variance in the result. […] A number of operations have backwards that use atomicAdd, in particular torch.nn.functional.embedding_bag(), torch.nn.functional.ctc_loss() and many forms of pooling, padding, and sampling. There currently is no simple way of avoiding non-determinism in these functions.
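
To make the order dependence concrete, here is a small CPU-only sketch (not taken from your code) showing that merely summing the same floating-point values in a different order already shifts the result slightly; atomicAdd introduces exactly this kind of undefined ordering on the GPU:

import torch

torch.manual_seed(0)
x = torch.randn(100000)

# Same values, two different accumulation orders
s1 = x.sum()
s2 = x[torch.randperm(x.numel())].sum()

print((s1 - s2).abs())  # usually a small, non-zero difference in float32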
I’ve compared two runs for a single epoch, and since the differences are increasing, I assume this might be due to the atomicAdd in the backward function of your pooling layers.
Here are the differences for the predictions and losses:
print((preds1 - preds2).abs())
> tensor([0.0000e+00, 0.0000e+00, 1.1921e-07, 2.6822e-07, 2.9802e-08, 1.6391e-07,
1.3411e-07, 1.4901e-07, 2.9802e-07, 4.7684e-07, 8.3447e-07, 2.3842e-07,
1.3113e-06, 5.9605e-07, 2.3842e-07, 4.7684e-07, 4.1723e-06, 6.8545e-06,
9.6858e-06, 1.2815e-05, 4.9993e-06, 6.1840e-06, 5.8860e-06, 6.7577e-06,
6.6683e-06, 5.5283e-06, 5.8264e-06, 6.8545e-06, 7.7151e-06, 8.6427e-06,
7.5623e-06, 6.1095e-06, 5.7220e-06, 1.9968e-06, 2.0415e-06, 9.5367e-07,
1.5795e-06, 2.2650e-06, 1.2517e-06, 4.8578e-06, 5.9009e-06, 1.0103e-05,
1.5825e-05, 2.7329e-05, 2.5600e-05, 3.5465e-05, 2.5690e-05, 1.8418e-05,
5.1737e-05, 9.8228e-05])
print((losses1 - losses2).abs())
> tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 5.9605e-08, 0.0000e+00, 5.9605e-08,
5.9605e-08, 5.9605e-08, 1.4901e-07, 8.9407e-08, 1.4901e-07, 1.1921e-07,
1.0729e-06, 1.0431e-07, 2.3842e-07, 2.3842e-07, 2.9802e-06, 4.5300e-06,
5.9009e-06, 5.7220e-06, 2.6226e-06, 3.0398e-06, 3.0398e-06, 3.3379e-06,
3.2783e-06, 2.9802e-06, 3.0994e-06, 3.6359e-06, 3.9935e-06, 4.3511e-06,
3.6359e-06, 3.2187e-06, 2.6822e-06, 1.1325e-06, 8.3447e-07, 3.5763e-07,
8.9407e-07, 8.9407e-07, 5.9605e-07, 2.8014e-06, 3.6955e-06, 3.9935e-06,
9.5367e-06, 1.1146e-05, 1.0610e-05, 1.4484e-05, 1.0252e-05, 7.1228e-06,
3.2902e-05, 3.5226e-05])
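
For reference, this is roughly how such a two-run comparison can be set up (a minimal sketch, not your training code: the toy model, data, and hyperparameters are placeholders, and whether non-zero differences actually show up depends on which CUDA kernels your model hits):

import torch
import torch.nn as nn

def run_epoch(seed=2809, steps=50):
    # Seed everything so both runs start from identical weights and inputs;
    # any remaining differences then come from non-deterministic kernels
    # (e.g. atomicAdd in pooling backwards).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 1),
    ).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    criterion = nn.MSELoss()
    preds, losses = [], []
    for _ in range(steps):
        x = torch.randn(16, 3, 32, 32, device='cuda')
        y = torch.randn(16, 1, device='cuda')
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        preds.append(out.detach().mean().cpu())
        losses.append(loss.detach().cpu())
    return torch.stack(preds), torch.stack(losses)

preds1, losses1 = run_epoch()
preds2, losses2 = run_epoch()
print((preds1 - preds2).abs())
print((losses1 - losses2).abs())

The growth of the differences over the iterations is expected: once a single gradient differs slightly, the optimizer updates diverge and the small errors compound from step to step, which matches the pattern in your outputs above.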