Convert normal tensor to quantized tensor and use quantized backend for matrix operations

Hi, I’m trying to build out a quantization module for a project and implement it from a lower level. I looked at the source code for the Observers and noticed the scale and zero_point are calculated in a way separate from some of the research papers I’ve read (e.g. for symmetric scale it’s (2 ^ (bits - 1) - 1) / max_x but in PyTorch it’s the max_x / ((quant_max - quant_min) / 2; let quant_min = - quant_max = 2 ** (bits - 1) - 1. It seems the methods are trying to accomplish the same thing but they’re an inverse of one another which in turn causes the wrong values (just zeros) when using torch.quantize_per_tensor() with my own calculated values.

My first question is wondering if it’s possible to take a weight that’s been quantized through my custom methods and then just convert it into a pytorch version that has a dtype of quint8 or qint8? Even better would be taking that weight and turning it into one of the supported quantized modules.

Next, is it possible to use the dot product for the quantized backend? I’ve tried before but I get errors (e.g. np.dot(qinput, qweight)) saying it’s not supported. I know there’s a FloatFunctional class that enables some operations to be computed in the quantized domain (e.g. torch.add(), torch.cat(), torch.mul()), but torch.dot() is not supported (even though it says it is which is misleading).

Below is an example of what I want to accomplish:

import torch
import torch.nn as nn

def compute_scale(beta, bits=8):
    return (2 ** (bits - 1) - 1) / beta

def quantize_symmetric(tensor, bits=8):
    beta = torch.max(torch.abs(torch.min(tensor)), torch.abs(torch.max(tensor))).item()
    scale = compute_scale(beta, bits)
    return torch.clamp(torch.round(tensor * scale), -2 ** (bits - 1) + 1, 2 ** (bits - 1) - 1)

def main():

    linear = nn.Linear(100, 10)
    weight = linear.weight
    qweight = quantize_symmetric(weight, bits=8)
    print(qweight)
    # use qweight with quantized backend: e.g. torch.dot(some_input, qweight) -> quantized_tensor

    #  qlinear = CustomQuantizedLinear(in_features, out_features, qweight)

if __name__ == "__main__":
    main()

Output:

tensor([[  35.,  -51.,  -64.,  -28.,  -99.,   66.,   94.,  115.,   12.,   43.,
          -38., -121.,   -5.,   80.,   78., -126.,  -16.,   96.,   82.,  -80.,
          -49.,  -61.,   81.,  116.,   32.,   49.,  -89.,   78.,   75.,    3.,
           74.,  -39.,   35.,    5., -113.,  -26.,  -31.,  -56.,  -42.,   89.,
           91.,  -46.,   -6.,   73.,  113.,  -96.,   29.,   19.,   20.,  -42.,
           59.,  -83.,   95.,   99.,  105.,  -92.,   13.,  -10.,   37.,   13.,
          111.,  -13.,   85.,   95.,  -15.,   90.,   -4., -125.,  -33.,   75.,
           -3.,   10.,  -87.,  -13.,  -10.,  -37.,  -85.,  123.,  108., -100.,
           -5.,  -13.,   76.,  -62.,  105.,   87.,   32., -113.,  -52.,  -59.,
          -55.,  -19., -104.,   27.,   88., -111.,  -24.,   91., -109., -101.],
        [  -6.,   89.,  -31.,  -48.,   -2.,   36.,   72.,  -87.,  111.,  -18.,
           22.,  100.,   -3., -111.,   53.,   29.,  -78.,  114.,  -51.,  -31.,
          -59., -102., -100.,    7.,  -98.,  -58.,  -36.,  107.,  -84.,  122.,
           43.,  -86.,   89.,  110.,   19.,  -74.,   96.,  -19.,   -5.,  -93.,
          -36., -110.,   81.,  -42.,  -11.,  -46.,   86.,  -17.,   85., -109.,
         -110.,   -2.,   38., -107.,  -32.,  -66.,   79.,   60.,  -73.,  -93.,
           16.,   57.,   37.,  -49.,    9.,   87.,   -3.,    1.,  120.,  114.,
           10.,   74., -127.,  -13.,   17.,  101.,   12.,  118.,  -84.,  -94.,
           -9.,  -50.,  100.,  -38.,  -70.,   61.,   74.,  -29.,  115.,  113.,
          -22.,   87.,  -63.,  -23.,   35., -121.,  -84.,   46.,   95.,  114.],
        [  77.,  -73.,  117.,   34.,    9., -101.,  -95.,  -55.,  -80.,   53.,
          108., -111.,  -17.,   -3.,  117.,   93.,   95.,  -45.,  112.,   25.,
           11., -106.,  -59., -117.,   95.,  -64.,  114.,  -22.,   16.,   29.,
           44.,   21.,  122.,   -7.,  -90.,  -72.,   79.,  -32.,  -60.,   81.,
           92.,  -43.,   50.,    2.,  -38.,  -21.,  117., -111.,   13.,   95.,
           51.,  113.,  -49.,   49.,   -2.,  -33.,  -34.,  -67.,  -94., -109.,
          105.,    6.,   55.,   90.,  -46.,   12.,  -74.,   47.,  -58.,  -91.,
           52.,  -65.,   86.,   10.,   66.,   77., -120.,   55.,   70.,   91.,
           51.,  -97.,  -31.,   46.,  -55.,   22.,   61.,  -41.,  -38.,   44.,
           55.,    5.,  -28.,  -66.,   32.,   69., -109.,  -28.,  -26.,  122.],
        [ -14.,    5.,  101.,  -51.,  -39., -119.,   -4.,  122.,   89.,  -96.,
          125.,  -55.,    5.,  -22.,  124.,    2.,  -31.,   99.,   42.,  -90.,
          -74.,   80.,  -27.,  -52.,   61.,  -80.,  -53.,    3.,  -46.,   87.,
          115., -102.,  -11.,  -74.,  -88.,  -78.,   85.,  -63.,   23.,   25.,
          -32.,  -79.,  109.,   93.,   10.,  -57.,   44.,   -3.,  -22., -106.,
            5.,   93., -119.,   47.,   32.,   51.,   63.,  -94., -118.,  -29.,
         -121.,   98.,  -11.,  -86.,   -2.,  -62.,  -14.,  106.,   69.,   24.,
          -51.,   54.,  127.,  121.,   17.,   77., -108.,  -93.,    6., -121.,
           27., -110.,   62.,  -23.,   43.,   18.,   49.,  -59.,  -21.,   97.,
          100.,   84.,  -32.,  -50.,   50.,   24.,  -19.,   76., -105.,   71.],
        [  62.,   -9.,  -55.,  125.,   22.,  117.,  -30.,  -27.,   -3.,  -10.,
          -81.,  121.,   81.,  -67.,  -54.,  107.,  -61., -106.,  -10.,  -92.,
          -59.,  119.,   23.,  -76.,   -7.,    5.,  122.,    7.,   49.,   -0.,
          -52.,   14.,  -97.,   51.,   56.,  -63.,   52.,   97.,  -99.,   41.,
          -71.,  -53.,  -27., -117.,   -6.,   87.,   -7.,   16.,    0.,  -55.,
           -6.,  -97.,  -36.,  -75.,  -61.,   21.,  -30.,   87.,   42.,   30.,
          -54.,   39.,  108.,  -41.,   91.,  -47., -101.,  -43.,   68.,  122.,
           62.,  -10.,  -98.,   26., -102.,   28.,  123.,  -84.,   19.,  -49.,
           13.,  -14.,   22.,  105.,   28., -102.,   89.,  -57., -118.,  -59.,
          -67.,   67.,  -41., -110.,   14.,  -96.,  -31.,  103.,   -5.,   83.],
        [  36., -108.,  -24.,  -24.,  119.,  115., -112.,  -48., -113., -101.,
         -107.,   15.,  110., -102.,  -55.,   78.,  111., -101.,   55.,  -64.,
          -13.,  -51.,  -87.,   94.,    9., -121., -121.,  111.,    3.,  -47.,
          113.,  -23.,  -76.,  -64.,    3.,   34.,  -46.,   73.,  -44.,    7.,
          -18.,   84.,   23.,  101.,  -94.,   -3.,  -27.,  -83.,  -74.,  -52.,
         -111.,  -36.,   12.,  -31.,  101.,  -69., -116.,  -65.,    6.,  127.,
          -20.,   36., -103.,   96., -112.,  -89.,  -89.,  114., -113.,   54.,
           42.,   -6.,   41.,   48., -112.,   90.,   45.,  104.,   52.,  -70.,
          -26.,   40.,  -61.,  -77.,   91.,   98.,   28.,  -64.,  124.,   15.,
          -47.,  -64., -120.,  -51.,   73.,  100.,  -91.,  116.,  -37.,   14.],
        [  95.,  -72.,  121., -112.,  126.,   82.,  111.,   40.,   93.,   98.,
           -1.,  116.,   59.,   34., -117.,   29., -119.,  -84., -124.,  -26.,
           20.,  -53.,    3.,  124.,  120., -120.,    6., -124.,  -42.,  122.,
            8.,   69.,  -29.,   95.,   80., -104., -125., -125.,   96.,  126.,
          -81.,  -30., -100.,  107.,   67.,   73.,   40.,  -63.,  -68.,  -19.,
         -110.,   87., -114., -122.,  103.,  113.,  -33.,  -57.,  -48.,  120.,
          -52.,   20.,  -12.,   81.,   88.,   37.,   23.,   38.,  -38.,  -90.,
          -65.,  -94.,  -58.,   96.,   10.,   11.,   -7.,  -78., -109.,   87.,
          -99.,  -94.,  -29.,   26.,   82., -101.,  -69.,   44.,  -87.,   45.,
          -35.,  -77.,  -36.,  -43.,  -58.,   19.,  -96.,  100.,  -60.,  -98.],
        [-106.,   89.,  -29.,   82.,   21.,   58.,   13.,   85.,  -28.,  125.,
          -47.,  107.,  -20.,  -99.,  -76.,   48.,  -34.,  -35.,  -18.,   72.,
           15.,  -81.,   75.,  -77.,   89.,  104., -125.,  -57.,  -79.,  -94.,
          117.,   44.,  -72.,  -46.,  105.,  112.,   28.,   61.,  -65.,   95.,
           88.,   51.,   17.,  -53., -114.,   86.,   69.,    3.,   42.,   27.,
           -8.,  -25.,  114.,  116.,  -79.,  -58.,  -50.,   46.,   -7.,   48.,
          -95.,  107.,   91.,  -11.,   10.,  -39.,  -14.,  -66.,   26.,  100.,
          -26.,  -15., -107.,   16.,  -25.,   16.,  -97.,  -58.,   48.,   93.,
            3., -115.,  -10.,  -62.,  100.,  -36.,   32.,   78.,   95.,  111.,
           90.,   40.,  -47.,  -12.,  -94.,   30.,  -52.,   82.,   41., -113.],
        [ -15.,   58.,  -41.,   66.,   87.,  -41.,   90., -117.,  108.,  -36.,
           21.,  -41.,  -77., -112.,  -75.,  122.,   25.,   51.,  -29., -120.,
            0., -112.,  -82.,  -26., -113., -105.,   64.,   31.,  -61.,  -52.,
          -74., -105.,  100.,  -59.,   11.,   -2., -106.,  -54.,  -77.,  -50.,
           79.,  -21.,  -84.,   17.,  -63., -100.,   91., -126.,   96.,  -40.,
           75.,  117.,  126., -104.,   88.,   97.,   24.,  -74., -106.,  -57.,
           57.,   79.,   13.,  -84.,  -56.,   -8.,   96.,   15.,  -48.,   -2.,
           76.,    0.,    1.,   61., -103.,   73.,  120.,   81.,   42.,   38.,
         -109.,  -99.,   98.,   27.,   60.,  -20.,  -61.,  -57.,   88.,   60.,
           56.,   29.,  -54.,   24.,  119.,   38.,  112.,   12.,  124.,  -28.],
        [  27.,   -2.,  110., -117., -113.,  -24.,  -29.,    9.,   81.,  -86.,
          -97.,   43., -102.,  -82.,   42., -103.,   -8.,  -55.,   30.,   37.,
           73.,  -67.,  -51.,  -18.,  -66.,    6.,   19.,  125.,  -32.,    1.,
          -18.,  -53.,   28.,  -22.,   -3.,  -50.,   39.,  -18.,   76.,   81.,
          -75.,   24.,  110.,  -41.,  -54.,   69., -116.,  -84.,  -10.,  114.,
          -56.,   87.,  -87.,  123., -106.,  106., -109.,  -25.,   71.,   93.,
          -67.,  -60.,   79.,  -17.,   43.,  -73.,  -47.,   20., -127.,   22.,
           78.,   -9.,   20.,  -15.,  107.,  -93.,   97.,  -20.,   -4.,  -27.,
          -82.,  100.,   -5.,  -92.,   42.,  -47., -116.,   71.,  104.,   48.,
           16.,   81.,   31.,    1., -124.,   77.,  -73.,  -69.,  -18.,  -13.]],
       grad_fn=<ClampBackward1>)