Extending PyTorch with Lower-than-8-bit Quantization

I am interested in using PyTorch for 1-bit neural network training. It seems to me that PyTorch currently only supports dtype=qint8. I am wondering if there is a good guide to the PyTorch dtype system and how to extend it.


cc @raghuramank100, who has a diff out, but it hasn't landed yet: https://github.com/pytorch/pytorch/pull/33743


Thanks @jerryzh168, that’s good info to keep an eye on.

https://github.com/pytorch/pytorch/pull/33743 nicely covers sub-8-bit quantization. However, I want to do 1-bit quantization, which quantizes the feature maps and weight matrices into {-1, 1}. This may require further changes to the qscheme. I am guessing that will require me to add some PyTorch intrinsics in ATen? Or is there a better way to accommodate that need?

In any case, I am looking forward to seeing 33743 land soon.

Right, if it is {-1, 1}, it is probably not affine quantization: what would you quantize 0 into?
I think you'll probably need to extend qscheme and implement a new quantizer (https://codebrowser.bddppq.com/pytorch/pytorch/aten/src/ATen/quantized/Quantizer.h.html) to support this.

if it is {-1, 1}, it is probably not affine quantization: what would you quantize 0 into?

I am trying to follow the XNOR-Net paper and its variants, which quantize the filter weights W into B such that W ≈ αB, where α is a real-valued scalar. Thus, by equation (4) in that paper, B_i = +1 if W_i >= 0 and B_i = -1 if W_i < 0.
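For reference, here is a minimal pure-Python sketch of that binarization for one flattened filter. The sign rule is the one quoted above. For the scale it assumes the least-squares choice from XNOR-Net, α = (1/n)·Σ|W_i|, i.e. the mean absolute weight. The function name `binarize_xnor` is hypothetical.

```python
from typing import List, Tuple


def binarize_xnor(weights: List[float]) -> Tuple[float, List[int]]:
    """Approximate W by alpha * B with B in {-1, +1} (XNOR-Net style).

    B_i = +1 if W_i >= 0 else -1. For that B, alpha = mean(|W_i|)
    minimizes ||W - alpha * B||^2.
    """
    if not weights:
        raise ValueError("weights must be non-empty")
    signs = [1 if w >= 0 else -1 for w in weights]
    alpha = sum(abs(w) for w in weights) / len(weights)
    return alpha, signs


w = [0.5, -0.25, 0.0, -1.25]      # one flattened filter
alpha, b = binarize_xnor(w)
approx = [alpha * s for s in b]   # the real-valued approximation alpha * B
print(alpha, b)
```

Note that a zero weight maps to +1, as the `W_i >= 0` rule requires. This is one reason the mapping does not fit an affine scheme that must represent 0 exactly.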

Looking at the ATen code, it seems to me that if we want to add support for such a binary quantization scheme, we will have to recompile PyTorch locally for the new ATen library to take effect. Is there any way to do this with the official PyTorch distribution, without recompiling?

Yeah, that is correct: you'll need to implement a new quantization scheme and probably add a new quantize function (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L3809), etc. I don't think we have a good way to simulate this in the current codebase, since this level of support is targeted at production.

However, if you just want to simulate the accuracy of the model in training, you might get away with less complete support in core. For example, you could just add a new quantization scheme and add support for it in the fake quantize module: https://github.com/pytorch/pytorch/blob/master/torch/quantization/fake_quantize.py#L92.
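To show what "simulate the accuracy in training" can look like without touching ATen, here is a sketch of a binary fake-quantize op written as a custom `torch.autograd.Function`. It is a hypothetical stand-in and does not use the `FakeQuantize` class from `fake_quantize.py`. It assumes a per-tensor scale α = mean(|x|). For the backward pass it uses a clipped straight-through estimator, which passes the gradient where |x| <= 1. It also ignores the gradient through α, which is a simplification.

```python
import torch


class BinaryFakeQuant(torch.autograd.Function):
    """Hypothetical binary fake-quantize op (not part of torch.quantization).

    Forward: x -> alpha * B, with B = +1 where x >= 0 else -1,
    and alpha = mean(|x|) over the whole tensor.
    Backward: clipped straight-through estimator; alpha is treated as a constant.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        alpha = x.abs().mean()
        b = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
        return alpha * b

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        # Pass gradients through unchanged inside [-1, 1], zero them outside.
        return grad_out * (x.abs() <= 1).to(grad_out.dtype)


class BinaryFakeQuantize(torch.nn.Module):
    """Module wrapper so the op can be dropped into a model during QAT-style training."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return BinaryFakeQuant.apply(x)


w = torch.tensor([0.5, -0.25, 0.0, -1.25], requires_grad=True)
y = BinaryFakeQuantize()(w)
y.sum().backward()
print(y.detach(), w.grad)
```

Here the forward output is α·B = [0.5, -0.5, 0.5, -0.5]. The last weight (-1.25) falls outside the clipping window, so its gradient is zero. This approach trains in float and only simulates the binary values. It does not provide a real 1-bit dtype or kernels.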

Either way, you'll probably need to add a new quantization scheme: https://github.com/pytorch/pytorch/blob/master/c10/core/QScheme.h

Thanks for the detailed explanation. This is very helpful.

I mainly aim to evaluate models in 1-bit precision with QAT, so I guess the second approach is good enough (for now). It seems that I will still need to touch the ATen/c10 code and recompile it locally. If so, is this something the PyTorch team would be interested in merging upstream in the future (assuming my code meets the quality requirements)?

On a related note, Nvidia’s new A100 architecture will support binary (1-bit) precision.

Acceleration for all data types, including FP16, BF16, TF32, FP64, INT8, INT4, and Binary

This is not too far from production. Details can also be found in their white paper. It feels like an interesting direction for the PyTorch community to explore, and it would be meaningful to support it in the long run.

Yeah, low-precision quantization is definitely something we want to pursue. It may not need extra quantization schemes, although it will need new data types. For example, we can have per-tensor affine quantization (an existing quantization scheme) with Int4 (a new data type).
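To make the "existing scheme + new data type" point concrete, here is a plain-Python sketch of per-tensor affine quantization narrowed to a 4-bit integer range. It uses the usual q = clamp(round(x / scale) + zero_point, qmin, qmax) and x̂ = scale · (q - zero_point) formulas. The min/max rule for choosing scale and zero point is a simple assumption, not PyTorch's observer code. The unsigned [0, 15] range is also an assumption; a signed INT4 would use [-8, 7].

```python
from typing import List, Tuple


def choose_qparams(xs: List[float], qmin: int, qmax: int) -> Tuple[float, int]:
    """Map [min(xs), max(xs)], widened to include 0, onto the integers [qmin, qmax]."""
    lo = min(min(xs), 0.0)
    hi = max(max(xs), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # all-zero input: any positive scale works
    zero_point = int(round(qmin - lo / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(xs: List[float], scale: float, zero_point: int, qmin: int, qmax: int) -> List[int]:
    # q = clamp(round(x / scale) + zero_point, qmin, qmax)
    return [max(qmin, min(qmax, int(round(x / scale)) + zero_point)) for x in xs]


def dequantize(qs: List[int], scale: float, zero_point: int) -> List[float]:
    # x_hat = scale * (q - zero_point)
    return [scale * (q - zero_point) for q in qs]


QMIN, QMAX = 0, 15  # assumed unsigned 4-bit code range
xs = [-1.0, -0.4, 0.0, 1.0, 2.0]
scale, zp = choose_qparams(xs, QMIN, QMAX)
qs = quantize(xs, scale, zp, QMIN, QMAX)
xs_hat = dequantize(qs, scale, zp)
print(scale, zp, qs, xs_hat)
```

Only the integer range changes between 8-bit and 4-bit. The scheme itself stays the same. The integer zero point guarantees that real 0 maps to a code exactly.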

In the case of 1-bit precision to {1, -1}, we also need a new quantization scheme, since it is not affine quantization. If the integer values are consecutive, e.g. {-1, 0, 1} or {0, 1}, I think we should be able to represent them with per-tensor affine quantization and a new INT1/INT2 data type.
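One way to see why {-1, 0, 1} and {0, 1} fit the affine scheme while {-1, 1} does not is to solve level = scale · (q - zero_point) for consecutive codes q = 0, 1, ... and check whether the zero point comes out as an integer. The helper `affine_params` is hypothetical and uses exact fractions to avoid float noise.

```python
from fractions import Fraction
from typing import Iterable, Optional, Tuple


def affine_params(levels: Iterable[int]) -> Optional[Tuple[Fraction, Fraction]]:
    """Return (scale, zero_point) with level_i = scale * (q_i - zero_point)
    for codes q_i = 0, 1, 2, ..., or None if the levels are not evenly spaced."""
    vals = sorted(Fraction(v) for v in levels)
    if len(vals) < 2:
        raise ValueError("need at least two levels")
    scale = vals[1] - vals[0]
    if any(b - a != scale for a, b in zip(vals, vals[1:])):
        return None
    zero_point = -vals[0] / scale
    return scale, zero_point


for levels in ([0, 1], [-1, 0, 1], [-1, 1]):
    scale, zp = affine_params(levels)
    print(levels, "scale =", scale, "zero_point =", zp, "integer:", zp.denominator == 1)
```

{0, 1} gives zero_point 0 and {-1, 0, 1} gives zero_point 1, both integers. {-1, 1} needs scale 2 and zero_point 1/2. No integer code lands on real 0, which is exactly the "what would you quantize 0 into?" problem.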

I agree with the comment on sub-8-bit quantization. We should be able to support 2- to 7-bit quantization using the existing infrastructure with some new data types, INT2 through INT7.

In the case of 1-bit (binary), you can represent {-1, 1} as {0, 1} by mapping -1 to 0. In fact, that's what will be implemented in hardware. However, that means replacing multiplication with XNOR. This change results in a separate set of operators/functionals/modules that need to be overloaded for binary networks. From a math point of view, I would like to see BNNs implemented this way (which exactly matches the hardware). However, it is a lot of work and hard to maintain (separately from all the other NN modules). Frankly, an engineer would argue that there is no significant benefit to doing so. I feel that a new data type, BINT1 for {-1, 1} (distinct from INT1 for {0, 1}), is a better choice.
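The hardware trick mentioned above can be checked in a few lines. Encode -1 as bit 0 and +1 as bit 1. Each position then contributes +1 to the dot product where the bits agree (XNOR = 1) and -1 where they differ. With n positions this gives dot = 2 · popcount(XNOR(a, b)) - n. The following is a minimal pure-Python sketch with hypothetical helper names. Real kernels pack the bits into machine words and use a hardware popcount instruction.

```python
from typing import List


def encode(v: List[int]) -> int:
    """Pack a {-1, +1} vector into an int: bit i is 1 iff v[i] == +1."""
    bits = 0
    for i, x in enumerate(v):
        if x not in (-1, 1):
            raise ValueError("entries must be -1 or +1")
        if x == 1:
            bits |= 1 << i
    return bits


def xnor_popcount_dot(a_bits: int, b_bits: int, n: int) -> int:
    """Dot product of two packed {-1, +1} vectors of length n."""
    mask = (1 << n) - 1
    agree = ~(a_bits ^ b_bits) & mask   # XNOR, restricted to the n valid bits
    matches = bin(agree).count("1")     # popcount
    return 2 * matches - n              # matches * (+1) + (n - matches) * (-1)


a = [1, -1, -1, 1, 1]
b = [1, 1, -1, -1, 1]
print(xnor_popcount_dot(encode(a), encode(b), len(a)), sum(x * y for x, y in zip(a, b)))
```

Both printed values agree. This is the multiply-accumulate replacement that a {0, 1} encoding implies, and it is why binary ops would need their own kernels rather than reusing the affine INT paths.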

I will try to experiment with this idea and submit issue/PR in the coming months.
