Limitations of Int8 QAT for Linear Layers

Hello,

I am trying to create a simple example of a linear layer model with int8 weights and int8 activations trained with QAT, using my own custom quantization config. I’ve noticed something odd, though: if the layer has 16 or fewer input nodes, the weights are not quantized after convert; with more than 16 input nodes, they are. This is a problem for me, because the model I eventually want to deploy has layers with fewer than 16 input nodes. Below is a minimal working example demonstrating the behavior:

import torch as th
import torch.nn as nn
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig

# Model configuration
IN_NODES = 16
OUT_NODES = 1
USE_BIAS = False

# Training configuration
BATCH_SIZE = 32
NUM_EPOCHS = 5


class Model(nn.Module):
    def __init__(self, in_nodes, out_nodes, use_bias):
        super().__init__()
        self.linear = nn.Linear(in_nodes, out_nodes, bias=use_bias)

    def forward(self, x):
        return self.linear(x)


def train(model, in_nodes, out_nodes, batch_size, num_epochs):
    optimizer = th.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(num_epochs):
        x = th.randn(batch_size, in_nodes)
        y = th.randn(batch_size, out_nodes)
        optimizer.zero_grad()
        preds = model(x)
        loss = th.nn.functional.mse_loss(preds, y)
        loss.backward()
        optimizer.step()
    return model

# QAT prepare: int8 activations + int8 weights (fake quant during training)
activation_config = IntxFakeQuantizeConfig(
    th.int8,
    "per_channel",
    is_symmetric=False,
    is_dynamic=True,
)
weight_config = IntxFakeQuantizeConfig(
    th.int8,
    "per_channel",
    is_symmetric=True,
    is_dynamic=False,
)

# Input nodes is 16
model = Model(IN_NODES, OUT_NODES, USE_BIAS)
quantize_(
    model,
    QATConfig(
        activation_config=activation_config,
        weight_config=weight_config,
        step="prepare",
    ),
)
# Train
model = train(model, IN_NODES, OUT_NODES, BATCH_SIZE, NUM_EPOCHS)
# Convert to quantized int8 model after training
quantize_(model, QATConfig(Int8DynamicActivationInt8WeightConfig(), step="convert"))
print(f"Weights are not quantized for <17 input nodes:\n{model.linear.weight}")

# Input nodes is now 17
IN_NODES = 17
model = Model(IN_NODES, OUT_NODES, USE_BIAS)
quantize_(
    model,
    QATConfig(
        activation_config=activation_config,
        weight_config=weight_config,
        step="prepare",
    ),
)
model = train(model, IN_NODES, OUT_NODES, BATCH_SIZE, NUM_EPOCHS)
quantize_(model, QATConfig(Int8DynamicActivationInt8WeightConfig(), step="convert"))
print(f"\nWeights are quantized for >=17 input nodes:\n{model.linear.weight}")

Corresponding output:

Weights are not quantized for <17 input nodes:
Parameter containing:
tensor([[ 0.1410, -0.0185, -0.1361,  0.0216,  0.2207,  0.0134, -0.0426,  0.0193,
          0.0845, -0.0320,  0.0028, -0.0894,  0.1710,  0.0796,  0.0805,  0.0795]])

Weights are quantized for >=17 input nodes:
LinearActivationQuantizedTensor(AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[  76,  -10,  -48,  -49,  -81,   59,   81,  -27,  127,   -7, -117,   58,
          -25,   51,   18,   86,   83]], dtype=torch.int8)... , scale=tensor([0.0019])... , zero_point=None... , _layout=PlainLayout()), block_size=(1, 17), shape=torch.Size([1, 17]), device=cpu, dtype=torch.float32, requires_grad=False), <function _int8_symm_per_token_reduced_range_quant at 0x13b381ee0>, quant_kwargs={}))
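To be explicit about what I am observing: the convert step behaves as if there is a minimum-width filter on the layer. Here is a pure-Python sketch of my guess (the threshold value, function name, and parameter are hypothetical, not torchao's actual code):

```python
# Sketch of the behavior I am observing, NOT torchao's actual implementation:
# layers below some minimum number of input features appear to be skipped
# during the convert step. The cutoff of 17 is inferred from my experiments.
def would_quantize(in_features: int, min_in_features: int = 17) -> bool:
    """Return True if a linear layer of this width appears to get quantized."""
    return in_features >= min_in_features
```

With this sketch, `would_quantize(16)` is False and `would_quantize(17)` is True, which matches the two runs above.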

Is there a gist somewhere showing how to train a simple model with QAT using int8 weights and int8 activations? What am I doing wrong in my example? Why does torchao refuse to quantize linear layers that have fewer than 17 input nodes?

Package versions (Python 3.12.12):

executorch              1.1.0
pytorch-tokenizers      1.1.0
torch                   2.10.0
torchao                 0.15.0

Thank you!