Quantization Aware Training in 8 bits only

Is there a way to perform QAT or mixed-precision training entirely in 8 bits? As far as I know, QAT currently works with fake quantization, which stores the weights in 32 bits, and we only see the int8 weights after converting the model.
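For context, here is a minimal sketch of the standard eager-mode QAT flow I mean (the toy model is just illustrative): during training the weights are still fp32 and only fake-quantized in the forward pass; real int8 weights only appear after `convert()`.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Toy model; QuantStub/DeQuantStub mark the quantized region in eager mode.
model = nn.Sequential(
    tq.QuantStub(),
    nn.Linear(16, 8),
    nn.ReLU(),
    tq.DeQuantStub(),
)

model.train()
model.qconfig = tq.get_default_qat_qconfig("fbgemm")
qat_model = tq.prepare_qat(model)

# During training the Linear weight is still an fp32 tensor (fake-quantized on the fly).
x = torch.randn(4, 16)
qat_model(x)
print(qat_model[1].weight.dtype)     # torch.float32

# Only after convert() do we get a real quantized module with int8 weights.
qat_model.eval()
int8_model = tq.convert(qat_model)
print(int8_model[1].weight().dtype)  # torch.qint8
```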

I’m working on a use case where the full training pipeline should be completed in 8 bits, because some edge hardware only supports 8-bit operations. That means both the forward and backward passes should run in 8 bits, and the weights should be stored in 8 bits as well. I’m aware of 8-bit optimizers (bitsandbytes, which covers the 8-bit side of the backward/update step), but I’m not sure how to do the forward pass in 8 bits and store all the weights as 8 bits while training a PyTorch model.
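For reference, this is roughly how I’d wire in the bitsandbytes 8-bit optimizer (model and hyperparameters are illustrative, and it needs a CUDA GPU); note that this keeps the optimizer states in 8 bits, while the forward pass below still runs in floating point, which is exactly the gap I’m asking about.

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(128, 64).cuda()
# 8-bit Adam: exp_avg / exp_avg_sq optimizer states are stored in 8 bits.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

x = torch.randn(32, 128, device="cuda")
loss = model(x).pow(2).mean()   # forward pass is still fp32
loss.backward()                 # gradients are still fp32 tensors
optimizer.step()                # optimizer states are kept in 8 bits
optimizer.zero_grad()
```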

Any suggestions?

I have the same question, but it seems like there is currently no open-source code for int8 training.

We had a recent intern who did some experiments on int8 training; maybe you can take a look: GitHub - fufeisi/Usage-of-the-8bit-Quantization-in-Neural-Network-Training: This repo has the script to reproduce the experiments in project 'Usage of the 8bit Quantization in Neural Network Training'. I think it compares the memory savings from dynamically quantizing the activations in the backward pass and from quantizing the gradients in the optimizer to 8 bits (the same approach as in bitsandbytes).
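If it helps, here is a rough sketch of the first idea (not the repo’s actual code; the class name and quantization scheme are made up for illustration): dynamically quantize the activation that is saved for backward to int8 plus a per-tensor scale, and dequantize it only when the weight gradient is computed.

```python
import torch

class Int8SavedLinearAct(torch.autograd.Function):
    """y = x @ W^T, but the copy of x saved for backward is stored as int8."""

    @staticmethod
    def forward(ctx, x, weight):
        # Per-tensor dynamic quantization of the activation needed for grad_weight.
        scale = x.abs().max().clamp(min=1e-8) / 127.0
        x_int8 = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
        ctx.save_for_backward(x_int8, scale, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x_int8, scale, weight = ctx.saved_tensors
        x_deq = x_int8.to(grad_out.dtype) * scale   # dequantize the saved activation
        grad_x = grad_out @ weight
        grad_w = grad_out.t() @ x_deq
        return grad_x, grad_w

# Usage:
x = torch.randn(8, 32, requires_grad=True)
w = torch.randn(16, 32, requires_grad=True)
y = Int8SavedLinearAct.apply(x, w)
y.sum().backward()
```

This only reduces the memory held between forward and backward; the matmuls themselves still run in floating point.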