Logits mismatch between PyTorch inference and manual implementation

I’m trying to manually reproduce the inference forward pass to understand exactly how quantized inference works. To do so, I trained and quantized a model in PyTorch using QAT, manually simulated the forward pass, and compared the outputs to PyTorch’s own quantized inference.

The activations, however, start to diverge right after the first layer: some channels differ by roughly 5–15 counts. I understand why values below the zero point (128 in this case) show up as 128 in my simulation, since the PyTorch logits are captured before the ReLU while my golden model applies it, but that does not explain the other discrepancies:

🔸 PyTorch Quantized Model Output
INT8 logits : [[182, 122, 163, 129, 113, 114, 165, 105, 139, 139, 152, 113, 179, 148, 159, 132]]

🔹 My Golden Model Simulation
layer 0 unsigned logits : [[182, 135, 163, 147, 128, 128, 183, 128, 152, 128, 168, 130, 179, 164, 140, 128]]
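
To separate the ReLU effect from the real error, the PyTorch pre-ReLU logits can be clamped at the zero point before diffing; a quick check on the two vectors above:

import numpy as np

pt_out     = np.array([182, 122, 163, 129, 113, 114, 165, 105, 139, 139, 152, 113, 179, 148, 159, 132])
golden_out = np.array([182, 135, 163, 147, 128, 128, 183, 128, 152, 128, 168, 130, 179, 164, 140, 128])

# clamp PyTorch's pre-ReLU logits at the zero point (128) to mimic the golden model's ReLU
pt_relu = np.maximum(pt_out, 128)
print((golden_out - pt_relu).tolist())  # nonzero entries are mismatches ReLU cannot explain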

This is the script I used for inference:

import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn import datasets


# ---------- Input (floating point) ----------
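# note: the training script below pickles the whole nn.Module with torch.save,
# so on PyTorch >= 2.6 this load needs weights_only=False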
model_int8 = torch.load("trained_qat_model2.pth", map_location="cpu")
#input_raw = np.array([5.3, 3.7, 1.5, 0.2], dtype=np.float32)  # setosa
input_raw = np.array([7.0, 3.2, 4.7, 1.4], dtype=np.float32)   # versicolor
#input_raw = np.array([7.7, 3.8, 6.7, 2.2], dtype=np.float32)  # virginica


# ---------- StandardScaler ----------
iris = datasets.load_iris()
scaler = StandardScaler().fit(iris.data)
input_float = scaler.transform([input_raw])[0]


###############################################################


capture = {}

def dump_first_layer(mod, inp, out):

    x_q = inp[0]
    w_q = mod.weight()
    b   = mod.bias() 
    
    capture["x_int8"]     = x_q.int_repr()                  # int8 input
    capture["s_x"]        = x_q.q_scale()
    capture["zp_x"]       = x_q.q_zero_point()

    capture["w_int8"]     = w_q.int_repr()                  # int8 weights
    capture["s_w"]        = w_q.q_per_channel_scales()      # fp32 weight scales
    capture["zp_w"]       = w_q.q_per_channel_zero_points() # int (all zero here – symmetric)

    capture["bias_fp32"]  = b.clone()

    capture["y_int8"]     = out.int_repr()                  # int8 after requant
    capture["s_y"]        = out.q_scale()
    capture["zp_y"]       = out.q_zero_point()


# ---------- PyTorch inference ----------
hook = model_int8.fc1.register_forward_hook(dump_first_layer)
test_input = torch.tensor([input_float], dtype=torch.float32)
logits_fp32 = model_int8(test_input)       
hook.remove()                            

print()
print("🔸 PyTorch Quantized Model Output")
print("input     :", capture["x_int8"])
print("INT8 logits :", (capture["y_int8"]).tolist())
###############################################################

# ---------- Simulate inference ----------
print()
print("🔹 Golden Model Simulation")

# ---------- real uint8 activations & first-layer params ----------
x_u8 = capture["x_int8"].numpy() # 0…255
zp_x = int(capture["zp_x"])      # 128 for our symmetric model 
s_in = float(capture["s_x"])

layers = [m for m in model_int8.modules()
          if isinstance(m, torch.nn.quantized.Linear)]

for i, layer in enumerate(layers):
    # ------- weight & per-output parameters -------
    w     = layer.weight().int_repr().numpy().astype(np.int8)
    s_w   = layer.weight().q_per_channel_scales().numpy()      
    b_fp  = layer.bias().detach().numpy()                      
    s_out = float(layer.scale)                                 
    zp_y  = int(layer.zero_point)                              

    # ------- 1. signed-domain input -------
    x_s8  = x_u8.astype(np.int16) - zp_x                       # [-128,127]
    
    # ------- 2. int32/64 MAC -------
    acc   = x_s8.astype(np.int32) @ w.T.astype(np.int32)     # [out]
    
    # ------- 3. integer bias before requant -------
    b_int = np.round(b_fp / (s_in * s_w)).astype(np.int32)
    acc  += b_int

    # ------- 4. requant (float form; ±1 LSB accurate) -------
    mult = (s_in * s_w) / s_out
    out_s32 = np.round(acc * mult).astype(np.int32)

    # ------- 5. add zp_y, clamp, optional ReLU -------
    out_u8  = out_s32 + zp_y
    out_u8  = np.clip(out_u8, 0, 255).astype(np.uint8)
    if i != len(layers)-1:                  
        out_u8[out_u8 < zp_y] = zp_y

    # ------- Print -------
    print(f"layer {i} unsigned logits :", out_u8.tolist())

    # ------- Feed next layer -------
    x_u8, zp_x, s_in = out_u8, zp_y, s_out
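
Step 4 of the loop uses the float form of the requantization, which as the comment notes is only ±1 LSB accurate. Integer-only kernels typically use a gemmlowp-style fixed-point multiplier instead, and the rounding difference between the two forms is one candidate for small mismatches. Below is a sketch of that variant (it is not guaranteed to be bit-identical to fbgemm's rounding), which could replace step 4 as out_s32 = requant_fixed_point(acc, mult):

def requant_fixed_point(acc, mult):
    # decompose each per-channel multiplier as q_m0 * 2**(exp - 31),
    # with q_m0 a Q31 fixed-point integer
    m0, exp = np.frexp(mult)                # mult = m0 * 2**exp, 0.5 <= m0 < 1
    q_m0  = np.round(m0 * (1 << 31)).astype(np.int64)
    shift = 31 - exp                        # total right shift (>= 31 since mult < 1 here)
    prod  = acc.astype(np.int64) * q_m0     # widen to 64 bit before multiplying
    half  = np.int64(1) << (shift - 1)      # add half an LSB, then shift: round-half-up
    return ((prod + half) >> shift).astype(np.int32)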

This is the script I used to train the model:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torch.utils.data import TensorDataset, DataLoader
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.ao.quantization import (
    MovingAverageMinMaxObserver, FakeQuantize, QConfig,
    default_per_channel_weight_fake_quant)

class QuantizableFNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.fc1 = nn.Linear(4, 16)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(16, 3)

    def forward(self, x):
        x = self.quant(x)
        x = self.relu1(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)
        return x



# Load and preprocess dataset
iris = datasets.load_iris()
X = StandardScaler().fit_transform(iris.data)
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=16, shuffle=True)


# Train with QAT
# Custom symmetric per-channel qconfig
model_fp32 = QuantizableFNN()

act_fake_quant = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        qscheme=torch.per_tensor_symmetric, 
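        # note: per_tensor_symmetric with the default quint8 dtype pins the
        # activation zero point at 128, which is why zp_x == 128 in the inference script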
        reduce_range=False
        )

model_fp32.qconfig = QConfig(
        activation=act_fake_quant,
        weight=default_per_channel_weight_fake_quant)


torch.quantization.prepare_qat(model_fp32, inplace=True)

optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(100):
    correct = 0
    total = 0
    running_loss = 0.0

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model_fp32(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy calculation
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        running_loss += loss.item()

    acc = 100.0 * correct / total
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f} | Accuracy: {acc:.2f}%")


# Convert and save
model_int8 = torch.quantization.convert(model_fp32.eval(), inplace=False)
torch.save(model_int8, "trained_qat_model2.pth")
print("✅ Quantized model saved as 'trained_qat_model2.pth'")

Hi @greifswald, we are moving away from the older eager mode and FX graph mode quantization stacks, so questions related to these will have low priority. But to respond to your question: you are simulating the quantized linear op implemented in PyTorch, and it’s definitely possible that there are differences in implementation, etc.

I think if you want to use PyTorch quantization, you should use our newer stack: Quantization — PyTorch main documentation. It has also recently moved from pytorch to torchao: ao/torchao/quantization/pt2e at main · pytorch/ao · GitHub
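
For reference, a minimal PT2E flow looks roughly like the sketch below; exact import paths vary across versions (the quantizer has lived under torch.ao.quantization.quantizer.xnnpack_quantizer and is moving to torchao), so treat it as a sketch rather than a drop-in:

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)

example_inputs = (torch.randn(1, 4),)
m = torch.export.export_for_training(model_fp32.eval(), example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)   # insert observers
m(*example_inputs)               # calibrate (or train with prepare_qat_pt2e for QAT)
m = convert_pt2e(m)              # lower to a quantized graph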
