High Quantization-Aware Training memory consumption

I have noticed that once Quantization-Aware Training starts, memory consumption increases significantly. For example, for a network that consumes 2GB during regular training, I see memory consumption jump to 4.5GB once QAT starts.
While this can partly be explained by the various statistics that need to be stored, the difference in memory consumption is much larger than what those stats occupy. My guess is that additional intermediate activation maps need to be kept around, but I was hoping someone more informed could drop in and clarify whether this is the reason.
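For reference, here is a minimal sketch of how one might measure this, assuming eager-mode QAT via torch.quantization.prepare_qat. The toy model, batch size, and qconfig name are illustrative assumptions, not the network from my setup, and it requires a CUDA device:

```python
# Minimal sketch (illustrative only): compare peak per-step GPU memory for a
# float model vs. the same model after prepare_qat(). Requires a CUDA device.
import copy
import torch
import torch.nn as nn

def peak_step_memory_mb(model, batch):
    """Peak extra GPU memory (MB) used by one forward/backward/optimizer step."""
    model = model.cuda().train()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    base = torch.cuda.memory_allocated()  # parameters already resident
    out = model(batch)
    out.sum().backward()
    opt.step()
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - base) / 1024 ** 2

# Toy model and input sizes are assumptions, not the 2GB network from the post.
float_model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(64 * 32 * 32, 10),
)
batch = torch.randn(32, 3, 32, 32, device="cuda")

qat_model = copy.deepcopy(float_model)
qat_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(qat_model, inplace=True)

print("float peak MB:", peak_step_memory_mb(float_model, batch))
print("QAT   peak MB:", peak_step_memory_mb(qat_model, batch))
```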

Hi @Georgios_Georgiadis, one known problem is that fake_quantize modules are currently implemented as additional nodes in the computation graph, so their additional outputs (the fake-quantized versions of the weights and activations) contribute to the memory overhead during training. We have plans to reduce this overhead in the future by adding fused fake_quant kernels for common layers such as conv and linear.
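As a rough illustration of this (assuming eager-mode QAT via prepare_qat; the toy model and qconfig are placeholders), listing the submodules of a prepared model shows the extra FakeQuantize nodes whose outputs have to be kept around for the backward pass:

```python
# Rough illustration (toy example, not the internals of any specific model):
# after prepare_qat(), each quantizable layer gains FakeQuantize submodules
# for its weights and activations; each one is an extra node in the autograd
# graph whose output tensor is kept for the backward pass.
import torch
import torch.nn as nn

m = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(),
    nn.Flatten(), nn.Linear(8 * 30 * 30, 10),
)
m.train()
m.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")  # assumed backend
torch.quantization.prepare_qat(m, inplace=True)

# List the inserted fake-quant nodes, e.g. "0.weight_fake_quant" and
# "0.activation_post_process" for the conv, plus one per observed activation.
for name, mod in m.named_modules():
    if isinstance(mod, torch.quantization.FakeQuantize):
        print(name)
```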