Is there a way to perform QAT or mixed-precision training entirely in 8 bits? As far as I know, QAT currently works with fake quantization, which keeps the weights in 32 bits during training, and we only get the actual int8 weights after converting the model.
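For reference, this is the standard eager-mode QAT flow I'm referring to (a minimal sketch with a toy model; the fbgemm qconfig is just an example):

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Toy model; QuantStub/DeQuantStub mark the region to be quantized
model = nn.Sequential(
    tq.QuantStub(),
    nn.Linear(16, 16),
    nn.ReLU(),
    tq.DeQuantStub(),
)

model.qconfig = tq.get_default_qat_qconfig("fbgemm")
tq.prepare_qat(model, inplace=True)   # inserts fake-quant modules

# During training the master weights are still float32:
print(model[1].weight.dtype)          # torch.float32

# ... training loop runs with fake-quantized float32 weights ...

model.eval()
int8_model = tq.convert(model)        # only now do we get real int8 weights
```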
I'm working on a use case where the full training pipeline has to run in 8 bits, because some edge hardware only supports 8-bit operations. That means both the forward and backward passes should be done purely in 8 bits, and the weights should be stored in 8 bits as well. I'm aware of 8-bit optimizers (bitsandbytes), which cover the 8-bit side of the backward pass/optimizer step, but I'm not sure how to do the forward pass in 8 bits and keep all the weights stored as int8 while training a PyTorch model.
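To make the question concrete, here is a rough sketch of the kind of layer I'm imagining (the class and everything in it are hypothetical placeholders, not an existing API): weights live as int8 buffers with a scale, activations are quantized on the fly, and the matmul accumulates in int32. The forward runs, but how to get gradients and updates for the int8 weights is exactly the part I'm missing:

```python
import torch
import torch.nn as nn

class Int8Linear(nn.Module):
    """Hypothetical linear layer whose weights are stored as int8."""

    def __init__(self, in_features, out_features):
        super().__init__()
        w = torch.randn(out_features, in_features)
        # Per-tensor symmetric scale for the weights
        self.register_buffer("w_scale", w.abs().max() / 127.0)
        self.register_buffer(
            "w_int8",
            torch.clamp(torch.round(w / self.w_scale), -128, 127).to(torch.int8),
        )

    def forward(self, x):
        # Quantize activations on the fly (per-tensor, symmetric)
        x_scale = x.abs().max() / 127.0
        x_int8 = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
        # int8 operands, int32 accumulation, rescale for the next layer
        acc = x_int8.to(torch.int32) @ self.w_int8.t().to(torch.int32)
        return acc.to(torch.float32) * (x_scale * self.w_scale)


layer = Int8Linear(16, 8)
x = torch.randn(4, 16)
print(layer(x).shape)  # torch.Size([4, 8])
# Autograd can't produce gradients for the int8 buffer, so the backward
# pass and the optimizer step in 8 bits are the open questions for me.
```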
Any suggestions?