Is it possible to do int16 QAT?

Hi, all:

I am trying to quantize a Siamese-style tracking model, and unfortunately I found that with int8 QAT the resulting model did not work at all. I am not sure whether my QAT setup is wrong or int8 is simply not precise enough for this task.

Since I only want a quantized backbone, the QAT setup is as follows:

1. Replace the skip-connection "+" with nn.quantized.FloatFunctional().
2. Write a fuse_model method for structures like conv-bn-relu (a simplified sketch of one such block follows after the snippet below), then:
    if cfg.TRAIN.QAT:
        # fuse conv-bn(-relu) patterns in the backbone before inserting observers
        model.eval()
        model.backbone.fuse_model()
        # int8 fake quant: per-channel for weights, per-tensor for activations
        model.backbone.qconfig = quantization.QConfig(
            weight=wt_fake_quant_per_channel_8bit,
            activation=act_fake_quant_8bit,
        )
        # insert fake-quant/observer modules into the backbone only
        model.backbone = quantization.prepare_qat(model.backbone, inplace=True)
        # the neck and head are meant to stay in fp32
        if cfg.ADJUST.ADJUST:
            model.neck.apply(quantization.disable_fake_quant)
        if cfg.BAN.BAN:
            model.ban_head.apply(quantization.disable_fake_quant)
        # drop the fake quant on the last skip-add output (the layers after it run in fp32)
        model.backbone.features[-1].skip_add.activation_post_process.apply(quantization.disable_fake_quant)
    model.train()
    # start training
    train(dist_model, optimizer, lr_scheduler, tb_writer)
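
For reference, here is a simplified sketch of what steps 1 and 2 look like on a single residual block (the block itself is made up for illustration, not my real backbone):

import torch
import torch.nn as nn
from torch import quantization  # torch.ao.quantization on newer versions

class ResidualBlockQAT(nn.Module):
    # toy conv-bn-relu / conv-bn block with a quantizable skip connection
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        # step 1: replace "+" with FloatFunctional so the add gets its own
        # observer / fake quant during QAT
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.skip_add.add(out, x)  # instead of `out + x`

    def fuse_model(self):
        # step 2: fuse conv-bn-relu and conv-bn so QAT works on the fused modules
        quantization.fuse_modules(
            self, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]], inplace=True
        )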

Is there something wrong with the code above? I disabled the last activation fake quant because the layers that follow the backbone run in float32; is this step necessary?
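
My understanding of the eager-mode convention is that the quantized-to-fp32 boundary is normally made explicit with a DeQuantStub around the quantized region, roughly like the hypothetical wrapper below (as far as I know the stubs only really matter at convert time):

import torch.nn as nn
from torch import quantization

class TrackerWithStubs(nn.Module):
    # hypothetical wrapper: only the backbone is quantized, neck/head stay fp32
    def __init__(self, backbone, neck, head):
        super().__init__()
        self.quant = quantization.QuantStub()      # fp32 -> quantized at the input
        self.backbone = backbone
        self.dequant = quantization.DeQuantStub()  # back to fp32 before neck/head
        self.neck = neck
        self.head = head

    def forward(self, x):
        x = self.quant(x)
        x = self.backbone(x)
        x = self.dequant(x)
        return self.head(self.neck(x))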

If the code above looks fine, I guess int8 is simply not enough. So, is it possible to do int16 QAT? I tried something like:

act_fake_quant_16bit = quantization.FakeQuantize.with_args(
    observer=MinMaxObserver_Clip6.with_args(
        quant_min=-32767,
        quant_max=32767,
        dtype=torch.qint32, 
        qscheme=torch.per_tensor_symmetric,  
        reduce_range=False
    ),
    quant_min=-32767,
    quant_max=32767,
    dtype=torch.qint32,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False
)

wt_fake_quant_per_channel_16bit = quantization.FakeQuantize.with_args(
    observer=quantization.PerChannelMinMaxObserver.with_args(
        quant_min=-32767,
        quant_max=32767,
        dtype=torch.qint32,
        qscheme=torch.per_channel_symmetric,
        reduce_range=False
    ),
    quant_min=-32767,
    quant_max=32767,
    dtype=torch.qint32,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False
)

but I get an error from torch saying that only the int8/uint8 dtypes are acceptable.
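
As a possible workaround I was thinking about hand-rolling the fake quant around the functional op instead of going through quantization.FakeQuantize (which is where the dtype check lives), something like the rough sketch below; I am not sure torch.fake_quantize_per_tensor_affine accepts a 16-bit range on every version:

import torch
import torch.nn as nn

class FakeQuantize16Bit(nn.Module):
    # rough sketch of a hand-rolled symmetric 16-bit per-tensor fake quant;
    # it tracks a running min/max and calls the functional fake-quant op directly
    def __init__(self, quant_min=-32767, quant_max=32767):
        super().__init__()
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

    def forward(self, x):
        if self.training:
            # update running min/max statistics (per tensor)
            self.min_val = torch.minimum(self.min_val, x.detach().min())
            self.max_val = torch.maximum(self.max_val, x.detach().max())
        bound = torch.maximum(self.min_val.abs(), self.max_val.abs())
        if not torch.isfinite(bound):
            return x  # no statistics collected yet, pass through
        # symmetric scale: max_abs / 32767 for the range [-32767, 32767]
        scale = (2 * bound / (self.quant_max - self.quant_min)).clamp(min=1e-12).item()
        return torch.fake_quantize_per_tensor_affine(
            x, scale, 0, self.quant_min, self.quant_max
        )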

The way you are using it is fine, but the eager mode quantization API does not have good customization support. Maybe you can try our new flow? See Quantization — PyTorch main documentation.
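
For example, with the FX graph mode flow (one of the flows on that page) the prepare step takes a QConfigMapping, so a custom qconfig can be scoped to just the backbone. A rough sketch using the names from your snippet (whether a 16-bit fake quant is accepted there still depends on the PyTorch version, and tracing the whole tracker may need a prepare_custom_config for non-traceable parts):

import torch
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

# reuse the 16-bit fake-quant factories defined above
qconfig_16bit = QConfig(weight=wt_fake_quant_per_channel_16bit,
                        activation=act_fake_quant_16bit)

qconfig_mapping = (
    QConfigMapping()
    .set_global(None)                            # leave everything else in fp32
    .set_module_name("backbone", qconfig_16bit)  # quantize only the backbone
)

example_inputs = (torch.randn(1, 3, 255, 255),)  # hypothetical input shape
model.train()
prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)
# ... run QAT training on `prepared`, then:
quantized = convert_fx(prepared.eval())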