I have implemented a transformer encoder-decoder architecture for regression. I pass the encoded sequence to the transformer decoder, which runs step by step: it starts with [encoded, zeros] and predicts the first token, then continues with [encoded, zeros, token1_pred] and predicts the second token, and so on.
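The decoding loop described above can be sketched roughly as follows. This is a NumPy stand-in, not the actual PyTorch model: `toy_decoder`, `greedy_decode`, the feature size `D` and the random weights are all hypothetical, the decoder is a single causal self-attention layer, and it is conditioned on the encoder output by simply adding its mean (a simplification in place of real cross-attention). The zero start token and "use only the last output" follow the description above.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8  # hypothetical feature size

# Hypothetical weights for a one-layer causal self-attention "decoder".
W_q, W_k, W_v = (rng.standard_normal((D, D)).astype(np.float32) / D ** 0.5 for _ in range(3))
W_out = rng.standard_normal((D, D)).astype(np.float32) / D ** 0.5


def toy_decoder(memory, tgt):
    """Return one output row per target position (shape [T, D]).

    Simplified stand-in: causal self-attention over `tgt`, conditioned on the
    encoder output `memory` by adding its mean (not real cross-attention).
    """
    x = tgt + memory.mean(axis=0)
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / D ** 0.5
    t = x.shape[0]
    future = np.triu(np.ones((t, t), dtype=bool), k=1)  # hide future positions
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return (weights @ v) @ W_out


def greedy_decode(memory, n_steps):
    tgt = np.zeros((1, D), dtype=np.float32)  # start from a zero token
    for _ in range(n_steps):
        out = toy_decoder(memory, tgt)        # outputs for every position
        next_token = out[-1:]                 # only the last one is used
        tgt = np.concatenate([tgt, next_token], axis=0)
    return tgt[1:]  # drop the zero start token


memory = rng.standard_normal((5, D)).astype(np.float32)
preds = greedy_decode(memory, n_steps=4)
print(preds.shape)  # (4, 8)
```

Note that every step re-runs the decoder over the whole prefix, so the outputs for earlier positions are recomputed each time rather than reused.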
When examining the outputs, I noticed that even with a single multi-head attention layer as the decoder, I get numerical discrepancies for identical inputs. The decoder always outputs the full sequence of tokens and we only use the last one, BUT the earlier ones should be identical to the outputs of the previous step, since their inputs are exactly the same. On inspection this is not the case: the errors start at around 10e-8, but as they propagate through the layers of the transformer and across the decoding timesteps, they grow so large that the outputs no longer make any sense.
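A minimal way to measure the effect described above, sketched in NumPy rather than PyTorch (the function names, sizes and random inputs are assumptions for illustration): run the same causal self-attention on a prefix of length T and on length T+1, and compare the first T output rows. Because of the causal mask they are mathematically identical; in floating point they are only guaranteed to be close, since the matrix products and row sums have different shapes and may be computed by different kernels or in a different order.

```python
import numpy as np


def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head causal self-attention; returns one row per input position."""
    d = x.shape[1]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = (q @ k.T) / d ** 0.5
    t = x.shape[0]
    future = np.triu(np.ones((t, t), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=1, keepdims=True)
    return w @ v


def prefix_drift(x, w_q, w_k, w_v):
    """Max |difference| between outputs for x[:-1] and the first rows of outputs for x.

    Exactly 0 in exact arithmetic (causal mask); only "small" in floating point.
    """
    short = causal_self_attention(x[:-1], w_q, w_k, w_v)
    full = causal_self_attention(x, w_q, w_k, w_v)
    return float(np.abs(full[:-1] - short).max())


rng = np.random.default_rng(0)
d, t = 64, 33
x = rng.standard_normal((t, d)).astype(np.float32)
w_q, w_k, w_v = (rng.standard_normal((d, d)).astype(np.float32) / d ** 0.5 for _ in range(3))
print(prefix_drift(x, w_q, w_k, w_v))  # small, but not guaranteed to be exactly 0.0
```

Whether the drift is exactly zero depends on the BLAS backend and the shapes; on GPUs with TF32 or fp16 kernels it is typically larger than in plain float32 on the CPU.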
I noticed the same behavior in Karpathy's nanoGPT repo, where the logits (the continuous values) show the same discrepancies. But since a softmax is applied there and tokens are selected from it, this never shows up in the outputs, because the selected indices remain the same.
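To illustrate why the discrete case hides the problem while regression does not, here is a small sketch assuming greedy (argmax) token selection; the logits and the size of the perturbation are made-up values standing in for rounding-level noise.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


logits = np.array([2.0, 0.5, -1.0, 1.2], dtype=np.float32)
# Hypothetical rounding-level perturbation of the same logits.
perturbed = logits + np.array([1e-6, -1e-6, 1e-6, -1e-6], dtype=np.float32)

# Language-model style (greedy): the continuous values differ, the chosen index does not,
# so the next step receives exactly the same discrete input.
token_a = int(np.argmax(softmax(logits)))
token_b = int(np.argmax(softmax(perturbed)))
print(token_a, token_b)  # 0 0

# Regression style: the continuous output *is* the prediction, so the perturbation
# is fed straight back into the next decoding step.
print(float(np.abs(perturbed - logits).max()))
```

With sampling instead of argmax, a tiny perturbation could in principle flip a sample, but only when two probabilities are nearly tied.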
The results of floating-point operations are subject to rounding errors, even when the computations are mathematically equivalent.
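A minimal example of this in float32: floating-point addition is not associative, so summing the same terms in a different order (as a kernel may do when the tensor shape changes) can give a different result. The values here are chosen only to make the effect obvious.

```python
import numpy as np

big = np.float32(1e8)
one = np.float32(1.0)

left = (big + one) - big   # 1e8 + 1 is not representable in float32, rounds back to 1e8
right = (big - big) + one  # same terms, different order
print(left, right)  # 0.0 1.0
```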
This can be a problem (I've seen it, e.g., with Llama's RoPE and bf16 vs. fp16), but there is only so much you can do about it, since deviations within numerical precision are an implementation detail of the operations.
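As a rough illustration of the bf16 vs. fp16 point for RoPE: for the highest-frequency rotary component the angle is simply the position, and bf16 (8 significant bits) cannot represent an integer like 2001, while fp16 (11 significant bits) can. This is only a sketch of the representational gap, not of where a particular Llama implementation loses precision; NumPy has no bfloat16 type, so `round_to_bfloat16` is a hypothetical helper that emulates it by rounding float32 bits (NaN/overflow handling omitted).

```python
import numpy as np


def round_to_bfloat16(x):
    """Emulate bfloat16 storage: keep the top 16 bits of float32, round-to-nearest-even.

    Simplified sketch: NaN and overflow handling are omitted.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32).astype(np.uint64)
    shift = np.uint64(16)
    lsb = (bits >> shift) & np.uint64(1)
    rounded = ((bits + np.uint64(0x7FFF) + lsb) >> shift) << shift
    return rounded.astype(np.uint32).view(np.float32)


position = np.float32(2001.0)
# Highest-frequency RoPE component: angle = position * 1.0
angle_fp32 = position
angle_bf16 = round_to_bfloat16(position)
angle_fp16 = np.float16(position)

print(float(angle_bf16), float(angle_fp16))  # 2000.0 2001.0
print(float(np.cos(angle_fp32) - np.cos(angle_bf16)))  # rotation differs by a full radian's worth
```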
Yes, it seems that in my case this cannot be avoided. It is surprising, because I never encountered this behavior with LSTMs, which keep their state, but transformers are very different and here I see some huge deviations.
Part of it might be that LSTMs were typically run in 32-bit precision, whereas now 16-bit or TF32 is used a lot. Also, the classification of operators by “fp16-safety” in automatic mixed precision hints that the softmax in the attention and the LayerNorms are potentially sensitive, while LSTMCell is safe (Automatic Mixed Precision package - torch.amp — PyTorch 2.1 documentation).
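One way to see why reductions such as the softmax denominator or the LayerNorm statistics are sensitive to fp16 is a deliberately naive left-to-right sum, sketched here in NumPy. Real GPU kernels use tree reductions and usually accumulate in float32, so the effect in practice is much milder; the sketch only shows how quickly fp16 runs out of precision. `sequential_sum` and the sequence length are assumptions for illustration.

```python
import numpy as np


def sequential_sum(values, dtype):
    """Accumulate left to right, rounding the running total to `dtype` after every add."""
    total = dtype(0)
    for v in values:
        total = dtype(total + v)
    return total


n = 4096
# Softmax denominator for n equal logits: the sum of exp(0) should be n.
exps16 = np.exp(np.zeros(n, dtype=np.float16))
exps32 = np.exp(np.zeros(n, dtype=np.float32))

print(sequential_sum(exps16, np.float16))  # 2048.0 (2048 + 1 rounds back to 2048 in fp16)
print(sequential_sum(exps32, np.float32))  # 4096.0
```

This is consistent with autocast listing `softmax` and `layer_norm` among the ops it runs in float32, while matmul-like ops and `LSTMCell` are allowed to run in float16.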
Yes, so far I only use 32-bit, but I intend to start using torch.amp, so this is a good thing to keep in mind, thanks!