Thank you very much for reading my post. I have done a fair amount of research but could not find a way to solve this issue. I would really appreciate any insights.
Main issue:
The model’s predictions differ significantly depending on the input size. This seems to come from round-off errors whose scale depends on the input shape. I am looking for ways to make my model robust to floating-point precision.
My thought:
During training, every input is padded so that the model always sees the same shape, (batch_size, max_seq_len, model_dim). During inference, however, the input shape varies, which leads to slightly different floating-point operations and a different scale of round-off error. As a result, the model’s inference performance is unstable and depends on the input shape. For example, padding inputs to max_seq_len yields more accurate inference results, while very short inputs give poorer performance.
Question:
- Why is my model so sensitive to floating-point precision?
- Are there any measures to make my model robust to floating-point round-off errors, or to variations in input shape?
The details follow below.
=== Details =========================================================
The details of the main issue:
For instance, running inference on the same input with and without padding to max_seq_len returns noticeably different results.
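As a point of reference for the comparison above, here is a minimal pure-Python sketch showing that padding is a mathematical no-op when the padded positions are masked with -inf before the softmax. The scores are hypothetical values, and the -inf masking is an assumption about how padding is handled, since the post does not say. If the real model shows a difference here, it comes either from the padding not being fully masked or from shape-dependent rounding in the kernels.

```python
import math

NEG_INF = float("-inf")

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical attention scores for three real key positions.
real_scores = [0.3, -1.2, 2.0]
# The same scores with two padded key positions masked out (assumed masking).
padded_scores = real_scores + [NEG_INF] * 2

unpadded = softmax(real_scores)
padded = softmax(padded_scores)

print(padded[:3] == unpadded)  # real positions get the same weights
print(padded[3:])              # masked positions get zero weight
```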
I am using a decoder-only transformer model. This behavior seems to stem from the accumulation (or propagation) of the small round-off errors that are expected in floating-point arithmetic. I acknowledge that such errors usually do not affect model performance; however, my model seems to be very sensitive to floating-point precision.
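To illustrate how the order of a reduction can change a low-precision result, here is a toy sketch that emulates bfloat16 rounding in pure Python. The helpers `to_bf16` and `bf16_sum` are hypothetical names written for this illustration and are not PyTorch functions. The example deliberately keeps the accumulator in bfloat16 to exaggerate the effect; real matmul kernels usually accumulate in float32 and round only the output, so the error in practice is much smaller, but the dependence on how the work is split is the same in kind.

```python
import struct

def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Toy emulation for small finite values; NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that truncating the low 16 bits rounds to nearest even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def bf16_sum(values):
    """Sum left to right, rounding the accumulator to bfloat16 each step."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)
    return acc

ones = [1.0] * 512

# One long reduction: the accumulator stalls once 256 + 1 rounds back to 256.
sequential = bf16_sum(ones)

# The same numbers reduced in four chunks of 128, then combined.
chunks = [bf16_sum(ones[i:i + 128]) for i in range(0, 512, 128)]
chunked = bf16_sum(chunks)

print(sequential)  # 256.0
print(chunked)     # 512.0
```

The two results are sums of the same values; only the grouping differs. A kernel that tiles a length-1024 reduction differently from a length-37 one can, in the same way, return slightly different numbers for mathematically identical inputs.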
I noticed this when I implemented a KV-cache. Without the cache, attention weights are computed as query@key.T; with the cache, they are computed as query[…, -1:, :]@key.T. Because only the last row matters during autoregressive inference, the two should give the same result. The model performed well without the KV-cache; however, its performance degraded significantly after I implemented it. I debugged extensively and concluded that this results from the different input shapes. I made sure that my implementation is mathematically correct.
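The equivalence described above can be checked in isolation. The following is a minimal sketch, not the actual model code: a single-head causal attention written in pure Python (float64), with a hypothetical `attend` function and random inputs, comparing the last row of the full computation against the cached path that uses only the last query.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(queries, keys, values):
    """Causal single-head attention.

    The queries are assumed to be the LAST len(queries) positions of the
    sequence, so query i may see keys up to index offset + i.
    """
    offset = len(keys) - len(queries)
    scale = 1.0 / math.sqrt(len(keys[0]))
    out = []
    for i, q in enumerate(queries):
        visible = offset + i + 1
        weights = softmax([dot(q, k) * scale for k in keys[:visible]])
        out.append([sum(w * v[d] for w, v in zip(weights, values[:visible]))
                    for d in range(len(values[0]))])
    return out

random.seed(0)
seq_len, dim = 6, 4  # hypothetical sizes
q = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(seq_len)]
k = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(seq_len)]
v = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(seq_len)]

full_last = attend(q, k, v)[-1]         # no cache: all queries, keep last row
cached_last = attend(q[-1:], k, v)[0]   # KV-cache path: last query only

max_diff = max(abs(a - b) for a, b in zip(full_last, cached_last))
print(max_diff)
```

Running the same kind of comparison on the real model in float32 (or float64) instead of bfloat16 may help separate a logic bug from rounding: a logic bug should persist at higher precision, while a rounding effect should shrink.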
The model:
I am afraid I cannot make my model public (partial code can be shown on request). However, the model is a very typical one: a stack of 12 decoder-only transformer layers followed by a final feed-forward layer that produces logits over a vocabulary of about 14,000 tokens. The main structure is identical to GPT-2 (including LayerNorm, residual connections, etc.). The model parameters are initialized with N(0, 0.02). There are roughly 100 million trainable parameters in total (the model is relatively small).
Training:
Just like GPT-2, the loss criterion is CrossEntropyLoss. Training was done in mixed precision with bfloat16, using torch.amp.autocast. After 100 epochs, the loss plateaued around 2.3, as expected. During training, all short inputs are padded to max_seq_len (1024). Because more than 50% of the data are short, most of the data are padded during training.
Inference:
Inference is also performed in bfloat16 using torch.amp.autocast(). The model performs autoregressive inference. Padding is not applied to inputs even if the sequence is shorter than max_seq_len (which I think is common practice). The KV-cache is optional. If the input is longer than max_seq_len, the KV-cache is not used. (As I mentioned earlier, the inference results with and without the cache are similar but significantly different, which seems to stem from the different input shapes in the matrix multiplications.)
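One possible diagnostic for the autoregressive setting, sketched under assumptions: if the gap between the two largest logits at some step is smaller than the spacing of bfloat16 values at that magnitude, rounding alone could flip the chosen token, and every later token would then diverge. The function names and the logit values below are hypothetical and chosen only for illustration.

```python
import math

def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values near x.

    bfloat16 keeps 8 significant bits, so the gap is 2**(exponent - 7).
    Subnormals are ignored in this sketch.
    """
    if x == 0.0:
        return 0.0
    exponent = math.floor(math.log2(abs(x)))
    return 2.0 ** (exponent - 7)

def top2_margin(logits):
    """Difference between the largest and second-largest logit."""
    first, second = sorted(logits, reverse=True)[:2]
    return first - second

# Hypothetical logits for one decoding step.
logits = [12.30, 12.27, 3.1, -0.5]

margin = top2_margin(logits)
step = bf16_spacing(max(logits))
fragile = margin < step  # True means rounding could change the argmax

print(step)     # 0.0625
print(fragile)  # True
```

Logging this margin per step for the cached and uncached runs would show whether the outputs diverge at a step where the top two candidates were nearly tied.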
Thank you very much again for your attention.