I’m trying to manually reproduce the inference forward pass to understand exactly how quantized inference works. To do so, I trained and quantized a model in PyTorch using QAT, manually simulated the forward pass, and compared the outputs against PyTorch’s own quantized inference.
The activations, however, start to diverge right after the first layer: some channels differ by a few units (roughly 5–15). I understand why values below the zero point (128 in this case) show up as 128 in my simulation, since the PyTorch logits are captured before ReLU while my simulation applies it, but that does not explain the other discrepancies:
🔸 PyTorch Quantized Model Output
INT8 logits : [[182, 122, 163, 129, 113, 114, 165, 105, 139, 139, 152, 113, 179, 148, 159, 132]]
🔹 My Golden Model Simulation
layer 0 unsigned logits : [[182, 135, 163, 147, 128, 128, 183, 128, 152, 128, 168, 130, 179, 164, 140, 128]]
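For reference, this is the per-output-channel arithmetic my golden model assumes (a condensed sketch in my own names; s_x, s_w, s_y and zp_y mirror the capture dict in the script below, and s_w may be a per-channel vector):

import numpy as np

def requantize(acc_int32, s_x, s_w, s_y, zp_y, apply_relu):
    # acc_int32 already contains the integer bias round(b / (s_x * s_w))
    y = np.round(acc_int32 * (s_x * s_w) / s_y).astype(np.int32) + zp_y
    y = np.clip(y, 0, 255)           # uint8 output range
    if apply_relu:
        y = np.maximum(y, zp_y)      # ReLU in the uint8 domain = clamp at zp_y
    return y.astype(np.uint8)

If this formula is already not what PyTorch's quantized Linear does internally, that alone would explain the offsets.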
This is the script I used for inference:
import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn import datasets
# ---------- Input (floating point) ----------
model_int8 = torch.load("trained_qat_model2.pth", map_location="cpu")
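# (if you're on PyTorch >= 2.6, torch.load may need weights_only=False here, since the file is a full pickled module)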
#input_raw = np.array([5.3, 3.7, 1.5, 0.2], dtype=np.float32) #setosa
input_raw = np.array([7.0, 3.2, 4.7, 1.4], dtype=np.float32) #versicolor
#input_raw = np.array([7.7, 3.8, 6.7, 2.2], dtype=np.float32) #virginica
# ---------- StandardScaler ----------
iris = datasets.load_iris()
scaler = StandardScaler().fit(iris.data)
input_float = scaler.transform([input_raw])[0]
###############################################################
capture = {}
def dump_first_layer(mod, inp, out):
x_q = inp[0]
w_q = mod.weight()
b = mod.bias()
capture["x_int8"] = x_q.int_repr() # int8 input
capture["s_x"] = x_q.q_scale()
capture["zp_x"] = x_q.q_zero_point()
capture["w_int8"] = w_q.int_repr() # int8 weights
capture["s_w"] = w_q.q_per_channel_scales() # fp32 weight scales
capture["zp_w"] = w_q.q_per_channel_zero_points() # int (all zero here – symmetric)
capture["bias_fp32"] = b.clone()
capture["y_int8"] = out.int_repr() # int8 after requant
capture["s_y"] = out.q_scale()
capture["zp_y"] = out.q_zero_point()
# ---------- PyTorch inference ----------
hook = model_int8.fc1.register_forward_hook(dump_first_layer)
test_input = torch.tensor([input_float], dtype=torch.float32)
logits_fp32 = model_int8(test_input)
hook.remove()
print()
print("🔸 PyTorch Quantized Model Output")
print("input :", capture["x_int8"])
print("INT8 logits :", (capture["y_int8"]).tolist())
###############################################################
# ---------- Simulate inference ----------
print()
print("🔹 Golden Model Simulation")
# ---------- real uint8 activations & first-layer params ----------
x_u8 = capture["x_int8"].numpy() # 0…255
zp_x = int(capture["zp_x"]) # 128 for our symmetric model
s_in = float(capture["s_x"])
layers = [m for m in model_int8.modules()
if isinstance(m, torch.nn.quantized.Linear)]
for i, layer in enumerate(layers):
# ------- weight & per-output parameters -------
w = layer.weight().int_repr().numpy().astype(np.int8)
s_w = layer.weight().q_per_channel_scales().numpy()
b_fp = layer.bias().detach().numpy()
s_out = float(layer.scale)
zp_y = int(layer.zero_point)
# ------- 1. signed-domain input -------
x_s8 = x_u8.astype(np.int16) - zp_x # [-128,127]
# ------- 2. int32/64 MAC -------
acc = x_s8.astype(np.int32) @ w.T.astype(np.int32) # [out]
# ------- 3. integer bias before requant -------
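    # (my assumption: the backend adds a bias quantized with scale s_in * s_w and zero point 0 to the int32 accumulator)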
b_int = np.round(b_fp / (s_in * s_w)).astype(np.int32)
acc += b_int
# ------- 4. requant (float form; ±1 LSB accurate) -------
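    # (the real backend may do this with a fixed-point multiplier instead of fp math, hence the ±1 LSB caveat)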
mult = (s_in * s_w) / s_out
out_s32 = np.round(acc * mult).astype(np.int32)
# ------- 5. add zp_y, clamp, optional ReLU -------
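    # (in the uint8 domain, ReLU is simply a clamp at the output zero point zp_y)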
out_u8 = out_s32 + zp_y
out_u8 = np.clip(out_u8, 0, 255).astype(np.uint8)
if i != len(layers)-1:
out_u8[out_u8 < zp_y] = zp_y
# ------- Print -------
unsigned_logits = out_u8.astype(np.uint8)
print(f"layer {i} unsigned logits :", unsigned_logits.tolist())
# ------- Feed next layer -------
x_u8, zp_x, s_in = out_u8, zp_y, s_out
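To rule out an indexing or broadcasting mistake in the loop above, I also recompute a single output channel (channel 0 of fc1) directly from the tensors captured by the hook (just a debugging sketch reusing the capture keys):

# recompute fc1 output channel 0 straight from the hook capture
x = capture["x_int8"].numpy().astype(np.int32) - int(capture["zp_x"])
w0 = capture["w_int8"].numpy().astype(np.int32)[0]   # weights of output channel 0
sw0 = float(capture["s_w"][0])                        # per-channel weight scale
acc0 = int((x * w0).sum())
acc0 += int(round(float(capture["bias_fp32"][0]) / (capture["s_x"] * sw0)))
y0 = int(round(acc0 * capture["s_x"] * sw0 / capture["s_y"])) + int(capture["zp_y"])
print("channel 0 -> mine:", int(np.clip(y0, 0, 255)), "| PyTorch:", int(capture["y_int8"][0, 0]))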
This is the script I used to train the model:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.utils.data import TensorDataset, DataLoader
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
#from quant_model import QuantizableFNN, xQuantizableFNN
from torch.quantization.observer import PerChannelMinMaxObserver, MovingAverageMinMaxObserver
from torch.ao.quantization import (
MovingAverageMinMaxObserver, FakeQuantize, QConfig,
default_per_channel_weight_fake_quant)
class QuantizableFNN(nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
self.fc1 = nn.Linear(4, 16)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(16, 3)
def forward(self, x):
x = self.quant(x)
x = self.relu1(self.fc1(x))
x = self.fc2(x)
x = self.dequant(x)
return x
# Load and preprocess dataset
iris = datasets.load_iris()
X = StandardScaler().fit_transform(iris.data)
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=16, shuffle=True)
# Train with QAT
# Custom symmetric per-channel qconfig
model_fp32 = QuantizableFNN()
act_fake_quant = FakeQuantize.with_args(
observer=MovingAverageMinMaxObserver,
qscheme=torch.per_tensor_symmetric,
reduce_range=False
)
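# note: as far as I understand, per_tensor_symmetric with the default quint8 activation dtype pins the zero point to 128, which is the 128 seen above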
model_fp32.qconfig = QConfig(
activation=act_fake_quant,
weight=default_per_channel_weight_fake_quant)
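# default_per_channel_weight_fake_quant is symmetric per-channel qint8, so the weight zero points are all 0 (matching the hook capture above)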
torch.quantization.prepare_qat(model_fp32, inplace=True)
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
correct = 0
total = 0
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model_fp32(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# Accuracy calculation
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
running_loss += loss.item()
acc = 100.0 * correct / total
avg_loss = running_loss / len(train_loader)
print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f} | Accuracy: {acc:.2f}%")
# Convert and save
model_int8 = torch.quantization.convert(model_fp32.eval(), inplace=False)
torch.save(model_int8, "trained_qat_model2.pth")
print("✅ Quantized model saved as 'trained_qat_model2.pth'")