PyTorch quantized linear function raises "shape is invalid" error

I am trying to write a simple quantized linear layer (a quantized tensor matrix multiplication). The weight matrix w3 has shape (14336, 4096) and the input tensor x has shape (2, 512, 4096), where the first dimension is the batch size. With the normal linear function everything works and the output has shape (2, 512, 14336). But when I quantize the tensors and call the quantized linear function, PyTorch raises an error:
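For reference, F.linear computes x @ w.T over the last dimension only and keeps every leading dimension, which is why a (2, 512, 4096) input with a (14336, 4096) weight yields (2, 512, 14336). A small sketch with hypothetical reduced sizes (not the ones from the question) shows the same behavior:

```python
import torch
import torch.nn.functional as F

# Hypothetical small stand-ins: x is (batch, seq, in_features),
# w is (out_features, in_features), the same layout as w3 in the question.
batch, seq, in_features, out_features = 2, 5, 8, 3
x = torch.randn(batch, seq, in_features)
w = torch.randn(out_features, in_features)

# F.linear contracts the last dim of x with dim 1 of w and keeps all leading dims.
out = F.linear(x, w)
print(out.shape)  # torch.Size([2, 5, 3])

# Equivalent explicit matmul with the transposed weight.
print(torch.allclose(out, x @ w.t(), atol=1e-5))
```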

```
Exception has occurred: RuntimeError
shape '[1, 2, 4096]' is invalid for input of size 4194304
    out_q = QF.linear(x_q, w3_q)
RuntimeError: shape '[1, 2, 4096]' is invalid for input of size 4194304
```

The code is as follows:

```python
import torch
import as QF
import torch.nn.functional as F

loaded_data = torch.load('')
w3 = loaded_data['ffn_w3'].type(torch.float16).to('cuda')
x = loaded_data['x'].type(torch.float16).to('cuda')
original_out = loaded_data['out'].type(torch.float16).to('cuda')

out_noquant = F.linear(x, w3)

def scale_zpt_compute(inp, Q_MAX, Q_MIN):
    scale = (inp.max() - inp.min()) / (Q_MAX - Q_MIN)
    zero_point = torch.round(- torch.min(inp) / scale) + Q_MIN
    return scale, zero_point

# quantize weight with qint8
Q_MAX = 127.0
Q_MIN = -127.0
scale, zero_point = scale_zpt_compute(w3, Q_MAX, Q_MIN)
w3_q = torch.quantize_per_tensor(w3.type(torch.float32), scale, zero_point, dtype=torch.qint8)

# quantize input with quint8
# apparently the quantized.nn.linear only supports quint8 as input
Q_MAX = 255.0
Q_MIN = 0.0
scale, zero_point = scale_zpt_compute(x, Q_MAX, Q_MIN)
x_q = torch.quantize_per_tensor(x.type(torch.float32), scale, zero_point, dtype=torch.quint8)

out_q = QF.linear(x_q, w3_q)
```
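For reference, scale_zpt_compute derives per-tensor affine parameters from the min and max, and per-tensor quantization then uses the standard affine mapping q = clamp(round(x / scale) + zero_point, Q_MIN, Q_MAX), with dequantization x ≈ (q - zero_point) * scale. The pure-Python sketch below mirrors that round trip on hypothetical sample values; it is an illustration of the formula, not PyTorch's internal kernel:

```python
def scale_zpt_compute(values, q_max, q_min):
    # Same min/max formula as in the question, on a plain Python list.
    scale = (max(values) - min(values)) / (q_max - q_min)
    zero_point = round(-min(values) / scale) + q_min
    return scale, int(zero_point)


def quantize(values, scale, zero_point, q_min, q_max):
    # q = clamp(round(x / scale) + zero_point, q_min, q_max)
    return [min(max(round(v / scale) + zero_point, q_min), q_max) for v in values]


def dequantize(q_values, scale, zero_point):
    # x_hat = (q - zero_point) * scale
    return [(q - zero_point) * scale for q in q_values]


x = [-1.0, -0.25, 0.0, 0.5, 3.0]            # hypothetical sample values
scale, zp = scale_zpt_compute(x, 255, 0)    # quint8 range, as used for the input
q = quantize(x, scale, zp, 0, 255)
x_hat = dequantize(q, scale, zp)

print(zp)  # 64
print(q)   # [0, 48, 64, 96, 255]
# Each reconstructed value is within scale / 2 of the original.
print(max(abs(a - b) for a, b in zip(x, x_hat)) <= scale / 2)
```

The minimum maps to Q_MIN and the maximum to Q_MAX, so the whole observed range is used; the price is a rounding error of at most half a step per element.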

My torch version is '2.4.1+cu121'. It seems that the quantized linear function QF.linear does not behave exactly like F.linear. The shape in the error message suggests that, in some internal computation, it moves the batch dimension to the second position and replaces the 512 dimension with 1, so the total number of elements no longer matches. Could you explain why this is happening?
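The numbers in the error message can be checked directly. The reshape target [1, 2, 4096] holds 8192 elements while x holds 4194304, a factor of exactly 512, so it is the sequence dimension that has no slot in the target shape. The same arithmetic is also consistent with internal code that assumes a 2-D (batch, in_features) input and prepends a singleton dimension; that reading is an inference from the numbers, not something confirmed in this thread.

```python
from math import prod

input_shape = (2, 512, 4096)   # shape of x in the question
reshape_target = (1, 2, 4096)  # shape quoted in the RuntimeError

n_input = prod(input_shape)      # 4194304, the "input of size" in the message
n_target = prod(reshape_target)  # 8192

# A reshape/view only succeeds when both element counts are equal.
print(n_input, n_target, n_input // n_target)  # 4194304 8192 512
```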

Hi @hafezmg48, this uses the old, CPU-only quantization code. Are you trying to run on CPUs? Our newer, GPU-friendly APIs are in pytorch/ao on GitHub (PyTorch native quantization and sparsity for training and inference).

As for why you see the error: it's not clear from the snippet alone, but to debug it I would read the source of the quantized functional linear (blame view of pytorch/torch/nn/quantized/ at commit 3e1fc85b23f9f12ff2ba5be645841bde90dba14e in pytorch/pytorch on GitHub) and check at which line your inputs stop producing sensible results.
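One quick experiment along these lines is to flatten the leading dimensions to 2-D before calling the quantized op, reshape the result afterwards, and compare it with F.linear. This is a CPU-only sketch with hypothetical small sizes; it assumes the functional ops are imported from torch.ao.nn.quantized.functional, uses symmetric qint8 weights as an illustrative choice, and has not been checked against the CUDA path from the question.

```python
import torch
import torch.nn.functional as F
import torch.ao.nn.quantized.functional as QF

torch.manual_seed(0)

# Hypothetical small sizes standing in for x (2, 512, 4096) and w3 (14336, 4096).
batch, seq, in_features, out_features = 2, 4, 16, 8
x = torch.randn(batch, seq, in_features)
w = torch.randn(out_features, in_features)
ref = F.linear(x, w)  # float reference, shape (2, 4, 8)


def affine_qparams(t, q_min, q_max):
    # Min/max affine parameters, same formula as scale_zpt_compute in the question.
    scale = (t.max() - t.min()).item() / (q_max - q_min)
    zero_point = int(round(-t.min().item() / scale)) + q_min
    return scale, zero_point


# Weights: symmetric qint8 with zero_point = 0 (an assumption for this sketch).
w_q = torch.quantize_per_tensor(w, w.abs().max().item() / 127, 0, torch.qint8)

# Activations: flatten to (batch * seq, in_features) first, then quantize to quint8.
# Restricting to 0..127 mirrors PyTorch's reduce_range option, which is sometimes
# needed to avoid overflow in x86 kernels.
x_2d = x.reshape(-1, in_features)
x_scale, x_zp = affine_qparams(x_2d, 0, 127)
x_q = torch.quantize_per_tensor(x_2d, x_scale, x_zp, torch.quint8)

# Output qparams taken from the float reference here; a real model would calibrate them.
out_scale, out_zp = affine_qparams(ref, 0, 255)
out_q = QF.linear(x_q, w_q, scale=out_scale, zero_point=out_zp)

# Dequantize and restore the original leading dimensions.
out = out_q.dequantize().reshape(batch, seq, out_features)
print(out.shape)                # torch.Size([2, 4, 8])
print((out - ref).abs().max())  # small quantization error
```

Passing scale and zero_point explicitly matters: when they are omitted, QF.linear reuses the input's quantization parameters for the output, which can clip results whose range differs from the input's.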

Hi @Vasiliy_Kuznetsov, thanks for the reply. No, I am trying to do GPU inference, which is why I move the weights to CUDA in the first lines. I am relatively new to this. I will look into the links, thanks.

Got it. In that case, I think torchao/quantization in pytorch/ao on GitHub (main branch) would be a good starting point.
