QConfig for ResNet50 with quint8 weights

Hello!
I am trying to perform quantization-aware training (QAT) of ResNet50 on ImageNet, but, unlike the default, I want the weights to be unsigned (that is, quint8, I think):

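```python
import torch
from copy import deepcopy

# get_resnet50 and QuantizedResNet50 are my own helpers (not shown here).
model = get_resnet50(pretrained=True, num_classes=1000)
assert torch.cuda.is_available(), "GPU not available"

cpu_device = torch.device("cpu")
device = torch.device("cuda")

model.to(cpu_device)
# Make a copy of the model for layer fusion
fused_model = deepcopy(model)

model.eval()
fused_model.eval()

# Fuse the stem: conv1 + bn1 + relu
fused_model = torch.quantization.fuse_modules(fused_model,
                                              [["conv1", "bn1", "relu"]],
                                              inplace=True)
# Fuse inside each residual block. The names conv1/bn1/relu1 and conv2/bn2 must
# match the block class used by get_resnet50 (torchvision's Bottleneck, for
# example, has conv1..conv3, bn1..bn3 and a single shared relu).
for module_name, module in fused_model.named_children():
    if "layer" in module_name:
        for basic_block_name, basic_block in module.named_children():
            torch.quantization.fuse_modules(
                basic_block, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]],
                inplace=True)
            for sub_block_name, sub_block in basic_block.named_children():
                if sub_block_name == "downsample":
                    # downsample is Sequential(conv, bn): fuse its "0" and "1"
                    torch.quantization.fuse_modules(sub_block,
                                                    [["0", "1"]],
                                                    inplace=True)

quantized_model = QuantizedResNet50(model_fp32=fused_model)
quantization_config = torch.quantization.QConfig(
    activation=torch.quantization.HistogramObserver.with_args(reduce_range=True),
    # Unsigned (quint8) per-channel weights: the dtype the assertion complains about
    weight=torch.quantization.PerChannelMinMaxObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_channel_affine))
quantized_model.qconfig = quantization_config
# QAT runs in training mode; recent PyTorch versions require it for prepare_qat
quantized_model.train()
torch.quantization.prepare_qat(quantized_model, inplace=True)
```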
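Because the model is fused while in eval mode, `fuse_modules` folds each BatchNorm into the preceding conv, as far as we understand the eager-mode fuser. The pure-Python sketch below shows that per-channel arithmetic on made-up numbers; `fold_bn_into_conv` and `conv_channel` are hypothetical helpers for illustration, not PyTorch's implementation. (Recent PyTorch releases also offer `torch.ao.quantization.fuse_modules_qat`, intended for fusing a model in training mode ahead of QAT.)

```python
import math

def fold_bn_into_conv(weights, bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Hypothetical helper: fold BatchNorm into conv weights/bias, one output channel at a time.

    weights: one flat list of kernel weights per output channel.
    bias: per-channel conv bias (zeros when the conv has bias=False).
    """
    fused_weights, fused_bias = [], []
    for c, w_c in enumerate(weights):
        # BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta is an affine map per channel,
        # so it can be absorbed by scaling the kernel and shifting the bias.
        factor = gamma[c] / math.sqrt(running_var[c] + eps)
        fused_weights.append([w * factor for w in w_c])
        fused_bias.append((bias[c] - running_mean[c]) * factor + beta[c])
    return fused_weights, fused_bias

def conv_channel(w_c, b_c, patch):
    """Output of one conv channel at one position: dot(kernel, input patch) + bias."""
    return sum(w * x for w, x in zip(w_c, patch)) + b_c

# Toy parameters (made-up values): two output channels with 3 taps each.
weights = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.0, 0.0]
gamma, beta = [1.2, 0.8], [0.1, -0.3]
mean, var = [0.4, -0.2], [0.9, 2.5]
patch = [1.0, 2.0, -0.5]

fw, fb = fold_bn_into_conv(weights, bias, gamma, beta, mean, var)
for c in range(len(weights)):
    y = conv_channel(weights[c], bias[c], patch)
    bn_out = gamma[c] * (y - mean[c]) / math.sqrt(var[c] + 1e-5) + beta[c]
    fused_out = conv_channel(fw[c], fb[c], patch)
    # conv followed by BN matches the single folded conv
    assert math.isclose(bn_out, fused_out, rel_tol=1e-9, abs_tol=1e-12)
```

Folding removes BN from the forward pass entirely, which is why eval-mode fusion suits inference; during QAT the BN statistics then stay frozen at their pretrained values.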

This is how I configure it. When I then launch the training, I get:

AssertionError: Weight observer must have a dtype of qint8

Is there any way to do it? Am I missing something?

Thank you!

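In your weight observer arguments, try changing `dtype=torch.quint8` to `dtype=torch.qint8`:

```python
import torch

# For QAT, observers are usually wrapped in torch.quantization.FakeQuantize.with_args(observer=...);
# with bare observers, prepare_qat only records statistics and does not fake-quantize.
quantization_config = torch.quantization.QConfig(
    activation=torch.quantization.HistogramObserver.with_args(reduce_range=True),
    weight=torch.quantization.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_affine))
```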
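With qint8 per-channel affine weights, each output channel gets its own scale and zero point on the signed range [-128, 127], and fake quantization rounds every weight onto that integer grid and maps it straight back to float. The pure-Python sketch below shows that round trip for two made-up weight channels. `affine_qparams` and `fake_quantize` are hypothetical helpers, not PyTorch functions; the scale and zero-point formula follows our reading of the affine branch of PyTorch's min/max observers and should be treated as an assumption.

```python
QINT8_MIN, QINT8_MAX = -128, 127

def affine_qparams(values, quant_min, quant_max, eps=1.1920928955078125e-07):
    """Hypothetical helper: scale/zero point for one channel, min/max-observer style (assumed)."""
    min_neg = min(min(values), 0.0)  # the quantized range must contain 0.0
    max_pos = max(max(values), 0.0)
    scale = max((max_pos - min_neg) / (quant_max - quant_min), eps)
    zero_point = quant_min - round(min_neg / scale)
    zero_point = max(quant_min, min(quant_max, zero_point))
    return scale, zero_point

def fake_quantize(values, scale, zero_point, quant_min, quant_max):
    """Quantize to integers, clamp to the dtype range, then dequantize back to float."""
    out = []
    for x in values:
        q = max(quant_min, min(quant_max, round(x / scale) + zero_point))
        out.append((q - zero_point) * scale)
    return out

# One row per output channel, as with per-channel quantization of conv weights.
weights = [[0.31, -0.12, 0.07, -0.44], [1.9, 0.5, -0.25, 0.0]]

for channel in weights:
    scale, zp = affine_qparams(channel, QINT8_MIN, QINT8_MAX)
    fq = fake_quantize(channel, scale, zp, QINT8_MIN, QINT8_MAX)
    # Every weight moves by at most half a quantization step.
    assert all(abs(a - b) <= scale / 2 + 1e-12 for a, b in zip(channel, fq))
    print(f"scale={scale:.6f} zero_point={zp} fake_quantized={[round(v, 4) for v in fq]}")
```

Training through this rounding is what lets QAT adapt the weights to the int8 grid before conversion.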

Yes, but I want the weights to have dtype quint8. Is that possible?

This is not supported right now because the backends (the fbgemm kernels) expect qint8 for weights.
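As far as the arithmetic goes, requiring qint8 does not shrink what an affine weight grid can represent: both 8-bit dtypes have 256 levels, so for the same channel they get the same scale, and the quint8 zero point and integer codes are exactly the qint8 ones shifted by 128. The pure-Python sketch below checks this on one made-up channel; `affine_qparams` and `quantize` are hypothetical helpers, and the qparams formula is our assumed reading of PyTorch's min/max observers, not library code.

```python
def affine_qparams(values, quant_min, quant_max, eps=1.1920928955078125e-07):
    """Hypothetical helper: affine scale/zero point over [quant_min, quant_max] (assumed formula)."""
    min_neg = min(min(values), 0.0)
    max_pos = max(max(values), 0.0)
    scale = max((max_pos - min_neg) / (quant_max - quant_min), eps)
    zero_point = quant_min - round(min_neg / scale)
    return scale, zero_point

def quantize(values, scale, zero_point, quant_min, quant_max):
    """Map floats to clamped integer codes."""
    return [max(quant_min, min(quant_max, round(x / scale) + zero_point)) for x in values]

channel = [0.31, -0.12, 0.07, -0.44, 0.2]
s_i8, zp_i8 = affine_qparams(channel, -128, 127)  # signed qint8 range
s_u8, zp_u8 = affine_qparams(channel, 0, 255)     # unsigned quint8 range
q_i8 = quantize(channel, s_i8, zp_i8, -128, 127)
q_u8 = quantize(channel, s_u8, zp_u8, 0, 255)

assert s_i8 == s_u8                                  # same number of levels, same step
assert zp_u8 == zp_i8 + 128                          # zero point shifts by 128
assert all(u == i + 128 for i, u in zip(q_i8, q_u8))  # and so does every code
# Dequantized weights are therefore identical for both dtypes.
assert [(q - zp_i8) * s_i8 for q in q_i8] == [(q - zp_u8) * s_u8 for q in q_u8]
```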


So it doesn’t even support qint32 for weight quantization?