Creating a custom layer and using torch.nn.qat for it

Hi, I have a specific case and need some help or pointers.

I have designed a specialized normalization layer (with learnable parameters) as a subclass of nn.Module, and I'd like to apply QAT to it. However, I couldn't find any documentation on how to write the corresponding QAT module, e.g. how to attach weight_fake_quant and activation_post_process.

Any pointers to get started?

We have an example for qat.Conv2d here: pytorch/ at master · pytorch/pytorch · GitHub
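qat.Conv2d follows roughly this pattern. Its constructor takes a qconfig and builds weight_fake_quant from qconfig.weight(). Its forward convolves with the fake-quantized weight. The classmethod from_float builds the QAT module from a float nn.Conv2d that already carries a qconfig. The short sketch below exercises that existing module. It assumes a recent PyTorch, where the module lives at torch.ao.nn.qat (older releases expose it as torch.nn.qat), and uses the "fbgemm" default QAT qconfig:

```python
import torch
import torch.nn as nn
import torch.ao.nn.qat as nnqat
from torch.ao.quantization import get_default_qat_qconfig

# A float conv with a QAT qconfig attached; from_float requires mod.qconfig.
float_conv = nn.Conv2d(3, 8, kernel_size=3)
float_conv.qconfig = get_default_qat_qconfig("fbgemm")

# Build the QAT module from the float one; it reuses the float parameters.
qat_conv = nnqat.Conv2d.from_float(float_conv)

# weight_fake_quant was created from qconfig.weight() in the constructor.
fq_weight = qat_conv.weight_fake_quant(qat_conv.weight)

# forward() runs the convolution on the fake-quantized weight.
out = qat_conv(torch.randn(1, 3, 8, 8))
```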

Then you can add the mapping in:

or pass in a mapping that includes the new qat module in pytorch/ at master · pytorch/pytorch · GitHub
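Passing the mapping explicitly looks roughly like the sketch below, assuming a recent PyTorch (torch.ao.quantization; older releases use torch.quantization). In recent releases, prepare_qat swaps each module whose type is a key of the mapping by calling the mapped class's from_float. It then attaches an output fake-quant (activation_post_process) to the swapped modules, so a custom QAT module should not need to create that attribute itself.

```python
import torch.nn as nn
import torch.ao.nn.qat as nnqat
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
from torch.ao.quantization.quantization_mappings import get_default_qat_module_mappings

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
model.qconfig = get_default_qat_qconfig("fbgemm")
model.train()  # prepare_qat expects a model in training mode

# Copy the default float -> QAT table; a custom layer would be added here,
# e.g. mapping[Custom] = qatCustom (names from the template later in this thread).
mapping = dict(get_default_qat_module_mappings())

# Swaps every mapped module via its from_float() and attaches output fake-quant.
qat_model = prepare_qat(model, mapping=mapping)
```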

Hi Jerry, thanks for sharing that. Yes, I had also seen the mapping table, but I should have been clearer: for a custom layer, I need to write a corresponding QAT version of that custom layer. Are there any particular requirements for such QAT modules (in the docs or in an example), such as a specific set of methods or attributes? One of them seems to be ‘from_float’.

I figured it out. For anyone with the same interest, I hope this template is helpful. Jerry, please comment if anything is missing here.

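import torch
import torch.nn as nn

class qatCustom(nn.Module):
    def __init__(self, *args, qconfig=None, **kwargs):
        super().__init__()
        assert qconfig is not None, "a QAT module needs a qconfig"
        self.qconfig = qconfig
        # define self.weight (and any other learnable parameters) as in Custom
        # skip this if there is no learnable parameter
        self.weight_fake_quant = qconfig.weight()

    def forward(self, input):
        # same forward as Custom, but computed with the fake-quantized weight
        qweight = self.weight_fake_quant(self.weight)
        # do compute (_forward_impl stands for the layer's own math)
        return _forward_impl(input, qweight)

    @classmethod
    def from_float(cls, mod):
        qconfig = mod.qconfig
        # pass the same constructor inputs that were used to build mod
        qmod = cls(..., qconfig=qconfig)
        qmod.weight = mod.weight
        return qmod

# register_qat_module_mapping exists only in some older releases; newer ones
# take the mapping through the mapping argument of prepare_qat instead
torch.quantization.quantization_mappings.register_qat_module_mapping(Custom, qatCustom)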
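Here is a concrete, runnable instance of the template above, written as a sketch. ScaleNorm, QATScaleNorm, and the per-tensor QConfig are hypothetical choices for illustration. They are not part of PyTorch or the original poster's layer. The sketch assumes a recent PyTorch with the eager-mode API under torch.ao.quantization, and it passes the mapping to prepare_qat instead of registering it globally. The weight fake-quant is per-tensor because the default QAT qconfig quantizes weights per channel, which may not suit a 1-D scale vector.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import QConfig, default_fake_quant, default_weight_fake_quant, prepare_qat
from torch.ao.quantization.quantization_mappings import get_default_qat_module_mappings


class ScaleNorm(nn.Module):
    """Hypothetical float layer: per-sample normalization plus a learnable per-channel scale."""

    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))

    def _forward_impl(self, x, weight):
        # Normalize each sample over all non-batch dims, then scale channel-wise.
        dims = tuple(range(1, x.dim()))
        mean = x.mean(dim=dims, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * weight.view(1, -1, *([1] * (x.dim() - 2)))

    def forward(self, x):
        return self._forward_impl(x, self.weight)


class QATScaleNorm(ScaleNorm):
    """Hypothetical QAT counterpart: same math, but on the fake-quantized weight."""

    _FLOAT_MODULE = ScaleNorm

    def __init__(self, num_channels, eps=1e-5, qconfig=None):
        super().__init__(num_channels, eps)
        assert qconfig is not None, "qconfig must be provided for a QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()
        # No activation_post_process here: prepare_qat attaches the output fake-quant.

    def forward(self, x):
        return self._forward_impl(x, self.weight_fake_quant(self.weight))

    @classmethod
    def from_float(cls, mod):
        # Called by prepare_qat for every ScaleNorm that carries a qconfig.
        assert type(mod) is cls._FLOAT_MODULE
        assert getattr(mod, "qconfig", None) is not None, "float module needs a qconfig"
        qmod = cls(mod.num_channels, mod.eps, qconfig=mod.qconfig)
        qmod.weight = mod.weight  # reuse the float parameter
        return qmod


# Per-tensor weight fake-quant keeps the sketch simple for a 1-D scale vector.
qconfig = QConfig(activation=default_fake_quant, weight=default_weight_fake_quant)

model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), ScaleNorm(4))
model.qconfig = qconfig
model.train()

mapping = {**get_default_qat_module_mappings(), ScaleNorm: QATScaleNorm}
qat_model = prepare_qat(model, mapping=mapping)

out = qat_model(torch.randn(2, 3, 5, 5))
out.sum().backward()
```

Subclassing the float layer keeps _forward_impl shared, so the QAT module differs only in which weight it feeds to the computation.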