AMP during inference

Does AMP speed up inference, or is it just for training? I used to think the reduced precision only applied to gradients (and was therefore irrelevant for inference), but I’ve seen mentions of it choosing other implementations for the forward pass too, which sounds like it might be designed for inference-only use cases as well.

Secondly, just to make sure - AMP might cause a slight decrease in accuracy compared to fp32, but considerably less than moving both the model and the data to fp16?

Hi @yiftach, my guess is that AMP isn’t really needed for inference. For a single forward pass we don’t need to keep track of the model’s activations as we move through the layers (each one can be discarded once the next layer has consumed it), whereas during training we need to keep all of the activations in memory for the backward pass. As such, the memory savings from mixed precision (in the sense that AMP implements) aren’t nearly as substantial for inference as they are for training.
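To put rough numbers on that argument, here is a back-of-the-envelope sketch in plain Python. Everything in it is an assumption for illustration: the layer count, the per-layer activation size, and the simplification that inference only needs about two layers' worth of activations alive at once (the current input and output). It ignores weights, residual connections, and framework overhead, so treat it as arithmetic, not a measurement.

```python
# Rough activation-memory arithmetic; all sizes below are made up.
BYTES_PER_ELEM = {"fp32": 4, "fp16": 2}


def activation_bytes(num_layers, elems_per_layer, dtype, training):
    """Approximate bytes of activations that must be held in memory.

    Training: every layer's output is kept for the backward pass.
    Inference: assume only the current layer's input and output are live.
    """
    live_layers = num_layers if training else min(num_layers, 2)
    return live_layers * elems_per_layer * BYTES_PER_ELEM[dtype]


NUM_LAYERS = 48                      # hypothetical model depth
ELEMS_PER_LAYER = 8 * 1024 * 1024    # hypothetical batch * tokens * hidden

for training in (True, False):
    fp32 = activation_bytes(NUM_LAYERS, ELEMS_PER_LAYER, "fp32", training)
    fp16 = activation_bytes(NUM_LAYERS, ELEMS_PER_LAYER, "fp16", training)
    mode = "training" if training else "inference"
    saved_mib = (fp32 - fp16) / 2**20
    print(f"{mode}: fp32={fp32 / 2**20:.0f} MiB, "
          f"fp16={fp16 / 2**20:.0f} MiB, saved={saved_mib:.0f} MiB")
```

Under these assumptions the relative saving is the same in both modes (half), but the absolute saving scales with the number of layers whose activations are kept, which is why it matters much more for training.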

That said, my understanding is that during inference the bottleneck is the memory required for the KV cache (just a subset of the activations), and I believe some engines offer quantized options to reduce overall memory usage, like the FP8 KV cache in vLLM.
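For a rough sense of why the KV cache gets singled out, the usual estimate is 2 (keys and values) × layers × KV heads × head dimension × sequence length × batch size × bytes per element. The sketch below just evaluates that formula; the model dimensions are hypothetical and not taken from any particular model or engine, and it does not use any vLLM API.

```python
# KV-cache size estimate; the model dimensions here are hypothetical.
BYTES_PER_ELEM = {"fp16": 2, "fp8": 1}


def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size, dtype):
    """Bytes needed to cache keys and values for every layer and token."""
    elems_per_token = 2 * num_layers * num_kv_heads * head_dim  # 2 = K and V
    return elems_per_token * seq_len * batch_size * BYTES_PER_ELEM[dtype]


config = dict(num_layers=32, num_kv_heads=32, head_dim=128,
              seq_len=4096, batch_size=8)

for dtype in ("fp16", "fp8"):
    size = kv_cache_bytes(dtype=dtype, **config)
    print(f"{dtype}: {size / 2**30:.1f} GiB")
```

Because the cache grows linearly with both sequence length and batch size, halving the bytes per element frees memory that can go directly toward longer contexts or larger batches.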

I’m still quite new to all of this, so please let me know if this doesn’t make sense!