Question about QAT to 4bit

Hello!

I am trying to quantize a model to 4 bits. My torch version is 1.7.1.
I have changed quant_min and quant_max in qconfig.py, fake_quantize.py, and observer.py, as shown below:

    # fbgemm QAT config: the activation fake-quant is given the range 0..15
    # (16 levels); weights use default_per_channel_weight_fake_quant.
    # NOTE(review): reduce_range=True is still passed here -- confirm whether
    # the observer then narrows the activation range further than 0..15.
    if backend == 'fbgemm':
        qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                            quant_min=0,
                                                            quant_max=15,
                                                            reduce_range=True),
                          weight=default_per_channel_weight_fake_quant)
# Per-channel symmetric weight fake-quant with the integer range set to
# [-8, 7] (16 levels). The dtype stays torch.qint8, so only quant_min/quant_max
# express the 4-bit restriction.
default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-8,
                                                               quant_max=7,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)
# Pick the integer range from the dtype: a signed range for qint8, a
# non-negative range otherwise. The full ranges hold 16 levels (4 bits);
# reduce_range halves them to 8 levels.
if self.dtype == torch.qint8:
    if self.reduce_range:
        quant_min, quant_max = -4, 3
    else:
        quant_min, quant_max = -8, 7

else:
    if self.reduce_range:
        quant_min, quant_max = 0, 7
    else:
        quant_min, quant_max = 0, 15

I have checked that the range of the weights in fake_quantize is correct (inside fake_quantize I quantize the weights to verify this).

However, when I run inference with the model, the maximum weight value is 8. It should be 7.

# Print the largest stored integer weight of every quantized Conv2d in the
# converted model.
for layer in model.modules():
    if not isinstance(layer, nn.quantized.Conv2d):
        continue
    w, b = layer._weight_bias()
    int_weight = w.int_repr()
    print(int_weight.max())

This is a screenshot of the result.

Is there anything else that needs to be changed?

Thanks!

Hi @wang_kevin , your setup looks right so far. Do you have a small reproducible code snippet you can share to help us look into it?

Hi @Vasiliy_Kuznetsov Thanks for the response!

I wrote a simple test script and it shows the same problem as above.

Here is my simple code:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import os

# Select the GPU before CUDA is initialised. CUDA_DEVICE_ORDER=PCI_BUS_ID makes
# device indices follow PCI bus order, so the index in CUDA_VISIBLE_DEVICES
# refers to a stable physical card. (The variable name was previously
# misspelled as "CUDA_DEVICE_OR_DER", which CUDA silently ignores.)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device('cuda')

class kernel_conv(nn.Module):
    """Conv2d + ReLU wrapped between a QuantStub and a DeQuantStub.

    Args:
        in_channels: number of input channels passed to ``nn.Conv2d``.
        out_channels: number of output channels passed to ``nn.Conv2d``.
        kernel_size: kernel size passed to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # Attribute names matter: the fusion step addresses the submodules
        # as "<block>.conv" and "<block>.relu".
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Apply quant -> conv -> relu -> dequant and return the result."""
        out = x
        for stage in (self.quant, self.conv, self.relu, self.dequant):
            out = stage(out)
        return out

class kernel_fc(nn.Module):
    """Linear + ReLU wrapped between a QuantStub and a DeQuantStub.

    Args:
        in_features: input width passed to ``nn.Linear``.
        out_features: output width passed to ``nn.Linear``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Attribute names matter: the fusion step addresses the submodules
        # as "<block>.fc" and "<block>.relu".
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Apply quant -> fc -> relu -> dequant and return the result."""
        quantized = self.quant(x)
        activated = self.relu(self.fc(quantized))
        return self.dequant(activated)

class Net(nn.Module):
    """Small CNN: two conv blocks with max-pooling, then three linear layers.

    The conv and the first two linear layers are ``kernel_conv`` /
    ``kernel_fc`` blocks (each with its own quant/dequant stubs); the final
    layer ``fc3`` is a plain ``nn.Linear`` producing 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = kernel_conv(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = kernel_conv(6, 16, 5)
        self.fc1 = kernel_fc(16 * 5 * 5, 120)
        self.fc2 = kernel_fc(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch and return the 10-way output."""
        # Convolutional stages, each followed by the shared pooling layer.
        for conv_block in (self.conv1, self.conv2):
            x = self.pool(conv_block(x))

        # Flatten all dimensions except batch before the linear layers.
        x = torch.flatten(x, 1)
        for head in (self.fc1, self.fc2, self.fc3):
            x = head(x)

        return x

# Convert images to tensors, then normalise every channel with mean 0.5 and
# std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 4

# CIFAR-10 training split (downloaded to ./data if missing), shuffled.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# CIFAR-10 test split, iterated in fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Class names for the ten labels.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

if __name__ == '__main__':
    # End-to-end QAT experiment: build the model, fuse + prepare it for
    # quantization-aware training, train for one epoch, then inspect the
    # integer values the first conv layer's weights map to.
    print("===> creating model <===")
    net = Net()
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()

    # Fuse each conv/linear with its ReLU while the model is in eval mode,
    # following the PyTorch eager-mode quantization recipe.
    net.eval()
    torch.quantization.fuse_modules(net, [["conv1.conv", "conv1.relu"],
                                            ["conv2.conv", "conv2.relu"],
                                            ["fc1.fc", "fc1.relu"],
                                            ["fc2.fc", "fc2.relu"]], inplace=True)

    net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # prepare_qat expects a model in training mode (newer PyTorch releases
    # assert this), so switch modes before inserting the fake-quant modules.
    net.train()
    torch.quantization.prepare_qat(net, inplace=True)

    # Build the optimizer only after fusion / QAT module swapping so it is
    # guaranteed to hold the parameters of the modules that actually train.
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(1):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0

    print('Finished Training')

    net.to('cpu')

    net.eval()

    # Inspect the integer grid of the first conv layer's weights.
    with torch.no_grad():
        for layer in net.modules():
            if isinstance(layer, nn.Conv2d):
                fake_quant = layer.weight_fake_quant
                x = torch.quantize_per_channel(layer.weight.detach(),
                                               fake_quant.scale,
                                               fake_quant.zero_point,
                                               0, torch.qint8)
                # quantize_per_channel with torch.qint8 saturates at the full
                # int8 range [-128, 127]; it knows nothing about the narrower
                # quant_min/quant_max the fake-quant trained with. A weight
                # slightly above scale * quant_max therefore rounds to
                # quant_max + 1 (e.g. 8 on a [-8, 7] grid). Clamp to the
                # fake-quant's own range to reproduce what QAT simulated.
                int_weight = x.int_repr().clamp(fake_quant.quant_min,
                                                fake_quant.quant_max)
                print('max: ', int_weight.max())
                print('min: ', int_weight.min())
                break