TIMM-based ViT masked-image modeling numerically unstable

I am training vision transformers with the masked-image-modeling paradigm of MAE (https://arxiv.org/abs/2111.06377), building off the paper's codebase, but with ViT-tiny models on CIFAR-10. The transformers are built from TIMM components, are pre-norm, use Xavier uniform initialization, and are trained with learning-rate warmup in pure 32-bit precision.
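For concreteness, here is a minimal sketch of Xavier-uniform initialization in the style of MAE's `_init_weights`, as I understand it: Linear weights Xavier-uniform with zero biases, LayerNorm weight one and bias zero, and the patch-embedding conv initialized as if it were a flattened Linear. The helper `init_mae_style`, the toy layer sizes, and the patch size of 4 are hypothetical stand-ins, not the actual training code.

```python
from typing import Optional

import torch
import torch.nn as nn


def _init_linear_and_norm(m: nn.Module) -> None:
    """Per-module init, assumed to mirror MAE's _init_weights."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def init_mae_style(model: nn.Module, patch_proj: Optional[nn.Conv2d] = None) -> None:
    """Hypothetical helper: Xavier-uniform init for a pre-norm ViT."""
    model.apply(_init_linear_and_norm)
    if patch_proj is not None:
        # Treat the patch-embedding conv (out, in, kh, kw) as a Linear of shape
        # (out, in * kh * kw), so Xavier sees the same fan-in/fan-out a Linear would.
        w = patch_proj.weight.data
        nn.init.xavier_uniform_(w.view(w.shape[0], -1))


torch.manual_seed(0)
# Toy stand-ins for ViT-tiny pieces: embed dim 192, MLP ratio 4, assumed patch size 4 on 32x32 images.
patch_proj = nn.Conv2d(3, 192, kernel_size=4, stride=4)
block = nn.Sequential(nn.LayerNorm(192), nn.Linear(192, 768), nn.GELU(), nn.Linear(768, 192))
init_mae_style(block, patch_proj)
```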

I can get deterministic results by enabling deterministic algorithms, fixing the seed, and pinning various PyTorch backend settings, e.g. which attention algorithm is used or whether TensorFloat-32 (TF32) is enabled. I also find that the linear probes I use to evaluate my pretrained models are robust to changes in these settings: validation accuracies typically fall within a 0.1% interval.
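A minimal sketch of how such settings can be pinned with standard PyTorch switches. `configure_run` and its `sdp_backend` strings are hypothetical names chosen for this sketch, and which flags actually matter may differ between PyTorch versions.

```python
import os
import random

import torch


def configure_run(seed: int, allow_tf32: bool, sdp_backend: str) -> None:
    """Hypothetical helper: fix seeds and the backend switches that change kernel numerics.

    sdp_backend is one of "math", "mem_efficient", "flash" (names invented for this sketch).
    """
    # cuBLAS needs this for deterministic GEMMs when deterministic algorithms are on;
    # it only takes effect if set before CUDA/cuBLAS is first used in the process.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators

    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # TensorFloat-32 for matmuls and cuDNN convolutions (only has an effect on Ampere or newer GPUs).
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    torch.backends.cudnn.allow_tf32 = allow_tf32

    # Allow exactly one scaled_dot_product_attention kernel family.
    torch.backends.cuda.enable_flash_sdp(sdp_backend == "flash")
    torch.backends.cuda.enable_mem_efficient_sdp(sdp_backend == "mem_efficient")
    torch.backends.cuda.enable_math_sdp(sdp_backend == "math")
```

Comparing, say, `configure_run(0, False, "math")` with `configure_run(0, False, "mem_efficient")` varies only the attention kernel while the seeds stay fixed. In fp32 the flash kernel is generally unavailable (it targets fp16/bf16), so restricting attention to it alone would likely fail.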

However, when I vary only these backend settings during pretraining, the differences are enormous: linear-probe validation accuracies spread across a roughly 2% interval. That is significant on the scale of the differences between results reported in published papers. Pretraining losses also differ, with relative differences between two runs above 1e-7 (about float32 machine epsilon). But this only shows up after many iterations of training, so the issue appears to arise during optimization.
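One way to locate where two runs separate is to log the per-step pretraining loss under two backend settings with the same seed, then find the first step whose relative difference exceeds a threshold. A sketch, assuming the losses are available as plain Python lists; `relative_diff` and `first_divergence` are hypothetical helpers and the loss values are made up. The default threshold is float32 machine epsilon; raising `rtol` (e.g. to 1e-3) instead finds where the gap stops being rounding-level.

```python
FLOAT32_EPS = 2.0 ** -23  # machine epsilon of IEEE-754 float32 (about 1.19e-7)


def relative_diff(a: float, b: float) -> float:
    """Symmetric relative difference |a - b| / max(|a|, |b|); 0 when both are 0."""
    scale = max(abs(a), abs(b))
    return 0.0 if scale == 0.0 else abs(a - b) / scale


def first_divergence(losses_a, losses_b, rtol=FLOAT32_EPS):
    """Return the first step index whose losses differ by more than rtol, or None."""
    for step, (a, b) in enumerate(zip(losses_a, losses_b)):
        if relative_diff(a, b) > rtol:
            return step
    return None


# Made-up loss logs from two runs that differ only in backend settings.
run_a = [1.0, 0.95, 0.91, 0.88]
run_b = [1.0 + 1e-7, 0.95 + 2e-7, 0.91 + 1e-5, 0.883]

print(first_divergence(run_a, run_b))             # -> 1 (beyond float32 rounding)
print(first_divergence(run_a, run_b, rtol=1e-3))  # -> 3 (macroscopic gap)
```

If the gap grows steadily from rounding level to macroscopic over many steps, that points at amplification during training rather than at a single badly conditioned operation.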

I've observed that switching from AdamW, the usual default for transformers, to SGD eliminates this issue, at the cost of learning far worse representations. So could AdamW's per-parameter rescaling be too sensitive to the small gradient differences that different backends produce? Increasing AdamW's eps partly works around this, but making eps large enough to fix the problem means barely rescaling at all, and indeed performance degrades enormously.
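A toy illustration of this hypothesis for a single scalar parameter on the first optimizer step, with weight decay ignored and the update written out by hand rather than calling `torch.optim`. After bias correction, the first Adam step is `-lr * g / (|g| + eps)`, roughly `-lr * sign(g)` when `|g|` is much larger than `eps`. If backend noise flips the sign of a near-zero gradient coordinate, Adam turns that into a full-size step in the opposite direction, whereas SGD moves by only `lr * g`; a larger `eps` shrinks exactly these small-gradient steps. The gradient values are made up, and the same `lr` is used for both optimizers only to keep the comparison simple.

```python
import math


def adam_first_update(grad, lr=1e-3, eps=1e-8, betas=(0.9, 0.999)):
    """Change applied to one scalar parameter by the first Adam/AdamW step (no weight decay)."""
    beta1, beta2 = betas
    m = (1 - beta1) * grad         # first-moment EMA after one step (starts at 0)
    v = (1 - beta2) * grad * grad  # second-moment EMA after one step (starts at 0)
    m_hat = m / (1 - beta1)        # bias correction for step 1
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


def sgd_update(grad, lr=1e-3):
    """Change applied by plain SGD (no momentum)."""
    return -lr * grad


# Two "backends" that disagree only in the sign of a near-zero gradient coordinate.
g_a, g_b = 1e-7, -1e-7

adam_gap = abs(adam_first_update(g_a) - adam_first_update(g_b))
sgd_gap = abs(sgd_update(g_a) - sgd_update(g_b))
adam_gap_large_eps = abs(adam_first_update(g_a, eps=1e-4) - adam_first_update(g_b, eps=1e-4))

print(f"Adam (eps=1e-8): {adam_gap:.2e}")
print(f"SGD:             {sgd_gap:.2e}")
print(f"Adam (eps=1e-4): {adam_gap_large_eps:.2e}")
```

The Adam gap is several orders of magnitude larger than the SGD gap, and raising `eps` collapses it, which is consistent with (though not proof of) the behavior described above.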

Is this normal? Documented anywhere? If not, any suggestions for tracking down the actual issue? I would massively appreciate help here.

(I can post a working example, but I will hold off since it’d be a fair amount of code and this might just be a known fact.)