I am trying to implement a simple quantized tensor linear multiplication. Assume a weight matrix w3 of shape (14336, 4096) and an input tensor x of shape (2, 512, 4096), where the first dimension is the batch size. With the normal linear function everything works fine and the output has shape (2, 512, 14336). But when I quantize the tensors and use the quantized linear function, PyTorch returns an error:
Exception has occurred: RuntimeError
shape '[1, 2, 4096]' is invalid for input of size 4194304
out_q = QF.linear(x_q, w3_q)
RuntimeError: shape '[1, 2, 4096]' is invalid for input of size 4194304
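For comparison, the unquantized call with placeholder random tensors of the same shapes (random data used purely for illustration, not my real weights) gives the expected output shape:

import torch
import torch.nn.functional as F

# placeholder tensors with the shapes described above (random data, illustrative only)
w3_demo = torch.randn(14336, 4096, dtype=torch.float16, device='cuda')  # (out_features, in_features)
x_demo = torch.randn(2, 512, 4096, dtype=torch.float16, device='cuda')  # (batch, seq_len, in_features)
print(F.linear(x_demo, w3_demo).shape)  # torch.Size([2, 512, 14336])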
The code is as follows:
import torch
import torch.ao.nn.quantized.functional as QF
import torch.nn.functional as F
loaded_data = torch.load('ffn_w3_example.pt')
w3 = loaded_data['ffn_w3'].type(torch.float16).to('cuda')         # weight, shape (14336, 4096)
x = loaded_data['x'].type(torch.float16).to('cuda')               # input, shape (2, 512, 4096)
original_out = loaded_data['out'].type(torch.float16).to('cuda')  # reference output loaded from the file
out_noquant = F.linear(x, w3)  # works fine, shape (2, 512, 14336)
def scale_zpt_compute(inp, Q_MAX, Q_MIN):
    # affine (asymmetric) quantization parameters from the tensor's min/max range
    scale = (inp.max() - inp.min()) / (Q_MAX - Q_MIN)
    zero_point = torch.round(-torch.min(inp) / scale) + Q_MIN
    return scale, zero_point
# quantize weight with qint8
Q_MAX = 127.0
Q_MIN = -127.0
scale, zero_point = scale_zpt_compute(w3, Q_MAX, Q_MIN)
w3_q = torch.quantize_per_tensor(w3.type(torch.float32), scale, zero_point, dtype=torch.qint8)
# quantize input with quint8
# apparently the quantized.nn.linear only supports quint8 as input
Q_MAX = 255.0
Q_MIN = 0.0
scale, zero_point = scale_zpt_compute(x, Q_MAX, Q_MIN)
x_q = torch.quantize_per_tensor(x.type(torch.float32), scale, zero_point, dtype=torch.quint8)
out_q = QF.linear(x_q, w3_q)
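For what it's worth, here is a small sanity check of scale_zpt_compute on a toy tensor (toy values chosen only for illustration, not my real data); the quantization parameters it produces look reasonable:

# toy check of scale_zpt_compute with the quint8 range used for the input above
toy = torch.tensor([-1.0, 0.0, 2.0])
s, zp = scale_zpt_compute(toy, 255.0, 0.0)
print(s.item(), zp.item())  # ~0.01176, 85.0
toy_q = torch.quantize_per_tensor(toy, s.item(), int(zp.item()), dtype=torch.quint8)
print(toy_q.int_repr())     # tensor([  0,  85, 255], dtype=torch.uint8)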
My torch version is '2.4.1+cu121'. It seems that the quantized QF.linear does not provide exactly the same functionality as F.linear. The shape in the error message suggests that some internal computation moves the batch-size dimension to second place and for some reason collapses the 512 dimension to 1: the view [1, 2, 4096] would hold only 8,192 elements, which cannot match the 2 × 512 × 4096 = 4,194,304 elements of the input, hence the exception. Could you please guide me on why this is happening?