Quantization of PyTorch models

@ptrblck
How do I quantize my model to FP16 after training it normally in PyTorch?

It depends on what kind of quantization you mean: the answer is different for post-training quantization (PTQ) and quantization-aware training (QAT).

PTQ

If you are just asking about post-training quantization (PTQ), you can simply cast the model's parameters and its inputs to torch.float16, like so:

import copy

import torch
from torch import nn

# This is for reference
model = nn.Linear(3, 3)
x = torch.randn(128, 3)
y = model(x)

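# This is the FP16 model: a copy whose weight and bias are cast to float16
# (copy.deepcopy(model).half() does the same for any module)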
model_16 = copy.deepcopy(model)
model_16.weight = nn.Parameter(model_16.weight.half())
model_16.bias = nn.Parameter(model_16.bias.half())

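# This is the FP16 computation: the input must also be float16 to match the weights.
# On some older PyTorch versions, float16 Linear is not implemented on CPU; use a CUDA device there.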
x_16 = x.half()
y_16 = model_16(x_16)

# Check the quantization error:
x_norm = torch.norm(x - x_16)
y_norm = torch.norm(y - y_16)

print(f'x error norm: {x_norm:.2f}, x abs max: {x.abs().max():.2f}')
print(f'y error norm: {y_norm:.2f}, y abs max: {y.abs().max():.2f}')

# Signal-to-quantization-noise ratio (SQNR) in dB: 20 * log10(||signal|| / ||error||)
x_sqnr = 20 * torch.log10(torch.norm(x) / torch.norm(x - x_16))
y_sqnr = 20 * torch.log10(torch.norm(y) / torch.norm(y - y_16))

print(f'x sqnr: {x_sqnr:.2f} dB')
print(f'y sqnr: {y_sqnr:.2f} dB')

## Results:
# x error norm: 0.00, x abs max: 3.13
# y error norm: 0.01, y abs max: 2.30
# x sqnr: 73.35 dB
# y sqnr: 67.21 dB
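The SQNR values above can be sanity-checked against the precision of float16. Assuming no overflow and ignoring subnormals, rounding a value to float16 has relative error at most the unit roundoff u = 2^-11 (half of `torch.finfo(torch.float16).eps`), so a plain cast has an SQNR of at least 20·log10(2^11) ≈ 66.2 dB. The sketch below (an illustration added here, not part of the original answer) measures the cast-only SQNR for random data and compares it with that bound. It only casts tensors and never runs a float16 kernel, so it should work on CPU.

```python
import math

import torch

torch.manual_seed(0)
x = torch.randn(128, 3)

# Round-trip through float16: the cast is the only source of error here
x_rt = x.half().float()

# Unit roundoff of float16 is half its machine epsilon, i.e. 2**-11
u = torch.finfo(torch.float16).eps / 2
bound_db = 20 * math.log10(1 / u)

# Same SQNR definition as in the snippet above
sqnr_db = 20 * torch.log10(torch.norm(x) / torch.norm(x - x_rt)).item()

print(f'float16 cast SQNR: {sqnr_db:.2f} dB (lower bound: {bound_db:.2f} dB)')
```

In the results above, the x SQNR of 73.35 dB sits above this bound. The output y also carries error from the float16 weights and the float16 arithmetic inside the layer, so the single-cast bound does not apply to it directly.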

QAT

For QAT, please refer to the Quantization — PyTorch 1.8.0 documentation (search for “fakequantize” and “fake quantization”).
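QAT in PyTorch is built around fake quantization: the forward pass simulates the rounding of the target format while the trainable weights stay in float32. The built-in fake-quantization modules simulate integer (affine) quantization, so as a hedged sketch of the same idea for FP16, the hypothetical `fake_fp16` helper and `FP16FakeQuantLinear` class below (neither is a PyTorch API) round weights and inputs to float16 in the forward pass and use a straight-through estimator in the backward pass. The model, data, and optimizer settings are arbitrary placeholders.

```python
import torch
import torch.nn.functional as F
from torch import nn


def fake_fp16(t: torch.Tensor) -> torch.Tensor:
    # Forward: the value is t rounded to float16 (but kept as float32).
    # Backward: straight-through estimator; the rounding residual is detached,
    # so the gradient passes through unchanged.
    return t + (t.half().float() - t).detach()


class FP16FakeQuantLinear(nn.Linear):
    """Hypothetical helper: nn.Linear that simulates float16 weights and inputs."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = fake_fp16(self.bias) if self.bias is not None else None
        # The matmul itself runs in float32; only the operands are rounded
        return F.linear(fake_fp16(x), fake_fp16(self.weight), bias)


torch.manual_seed(0)
model = FP16FakeQuantLinear(3, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(128, 3)
target = torch.randn(128, 3)

losses = []
for _ in range(5):
    opt.zero_grad()
    loss = F.mse_loss(model(x), target)
    loss.backward()
    opt.step()
    losses.append(loss.item())

# The float32 master weights were trained against float16 rounding;
# for deployment they can now be cast to real float16 (in place)
model.half()
```

This only pays off if casting after training measurably hurts accuracy; the SQNR figures in the PTQ section suggest the float16 rounding error is small for this layer.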


We are also providing a way to do FP16 static quantization in FX graph mode quantization; it is ready in master. You can find an example in the related tests:
