Creating a custom layer and using torch.qat for it

Hi, I have a specific case and need some help or pointers.

I have designed a specialized normalization layer (with learnable params) as an nn.Module, and I would like to apply QAT to it. However, I couldn't find any documentation on how to write the corresponding QAT module, e.g. how to attach weight_fake_quant and activation_post_process.

Any pointers on where to start?

We have an example for qat.Conv2d here: pytorch/conv.py at master · pytorch/pytorch · GitHub

Then you can add the mapping in:

or pass a mapping that includes the new QAT module to prepare_qat in pytorch/quantize.py at master · pytorch/pytorch · GitHub
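To make weight_fake_quant concrete: in qat.Conv2d it is built from qconfig.weight(), which returns a FakeQuantize module that observes the weight, updates its scale/zero_point, and rounds the weight onto the quantization grid while keeping it a float tensor. The snippet below is a small sketch assuming the torch.ao.quantization namespace of recent PyTorch releases (older releases expose the same names under torch.quantization) and its built-in default_qat_qconfig, whose weight fake-quant is per-tensor symmetric qint8.

```python
import torch
from torch.ao.quantization import default_qat_qconfig

# qconfig.weight() builds a fresh FakeQuantize module; a QAT module keeps
# one of these as self.weight_fake_quant.
weight_fake_quant = default_qat_qconfig.weight()

torch.manual_seed(0)
w = torch.randn(16)
# In training mode this observes w, updates scale/zero_point, then fake-quantizes.
w_fq = weight_fake_quant(w)

# The result is still float, but each value is an integer multiple of scale
# (zero_point is 0 for symmetric qint8), clamped to the int8 range.
q = torch.round(w_fq / weight_fake_quant.scale)
print(w_fq.dtype, int(q.min()), int(q.max()), int(weight_fake_quant.zero_point))
```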

Hi Jerry, thanks for sharing that. Yes, I also saw the mapping table, but I should have been clearer: for a custom layer, I need to write a corresponding QAT version of that custom layer. Are there any particular requirements for such QAT versions (in the docs or an example), such as a particular set of functions or attributes? One of them looks like ‘from_float’.

I figured it out. For anyone with the same interest, I hope this template is helpful. Jerry, please comment if anything is missing here.

class qatCustom(nn.Module):
    def __init__(..., qconfig=None):
        super().__init__(....)
        self.qconfig = qconfig
        # skip this if there are no learnable parameters
        self.weight_fake_quant = qconfig.weight()

    def forward(self, input):
        # implement the forward pass, but with the fake-quantized weight
        qweight = self.weight_fake_quant(self.weight)
        # do the compute
        return _forward_imple(input, qweight)

    @classmethod
    def from_float(cls, mod):
        qconfig = mod.qconfig
        qmod = cls(qmod constructor inputs, qconfig=qconfig)
        qmod.weight = mod.weight
        return qmod
     
torch.quantization.quantization_mappings.register_qat_module_mapping(Custom, qatCustom) 
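A concrete, runnable instance of the template above might look like the following. It is a sketch under assumptions: ChannelScaleNorm and QATChannelScaleNorm are hypothetical stand-ins for the custom normalization layer, it uses the torch.ao.quantization namespace of recent PyTorch releases, and it passes the mapping to prepare_qat instead of calling register_qat_module_mapping, which is not available in every release. The QAT class subclasses the float one so the forward math lives in one place, and _FLOAT_MODULE follows the convention of PyTorch's built-in QAT modules (used here only for a sanity check).

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    default_qat_qconfig,
    get_default_qat_module_mappings,
    prepare_qat,
)


def _scale_norm(x, weight, eps):
    # Hypothetical custom norm: divide by the RMS over all non-batch dims,
    # then apply a learnable per-channel scale (channels on dim 1).
    dims = tuple(range(1, x.dim()))
    rms = x.pow(2).mean(dim=dims, keepdim=True).add(eps).sqrt()
    shape = [1, -1] + [1] * (x.dim() - 2)
    return x / rms * weight.view(shape)


class ChannelScaleNorm(nn.Module):
    """Float module (hypothetical stand-in for the custom norm layer)."""

    def __init__(self, num_channels, eps=1e-5):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_channels))

    def forward(self, x):
        return _scale_norm(x, self.weight, self.eps)


class QATChannelScaleNorm(ChannelScaleNorm):
    """QAT counterpart: same math, but the weight goes through a fake-quant."""

    _FLOAT_MODULE = ChannelScaleNorm

    def __init__(self, num_channels, eps=1e-5, qconfig=None):
        super().__init__(num_channels, eps)
        assert qconfig is not None, "qconfig must be provided for a QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()

    def forward(self, x):
        return _scale_norm(x, self.weight_fake_quant(self.weight), self.eps)

    @classmethod
    def from_float(cls, mod):
        assert type(mod) is cls._FLOAT_MODULE, "unexpected float module type"
        assert getattr(mod, "qconfig", None) is not None, "float module needs a qconfig"
        # "constructor inputs" = whatever __init__ needs to rebuild the module
        qat_mod = cls(mod.num_channels, mod.eps, qconfig=mod.qconfig)
        qat_mod.weight = mod.weight  # share the (possibly pretrained) parameter
        return qat_mod


# Extend a copy of the default float -> QAT mapping with the custom pair.
mapping = dict(get_default_qat_module_mappings())
mapping[ChannelScaleNorm] = QATChannelScaleNorm

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), ChannelScaleNorm(8))
model.qconfig = default_qat_qconfig
model.train()  # prepare_qat expects a model in training mode
qat_model = prepare_qat(model, mapping=mapping)

out = qat_model(torch.randn(2, 3, 16, 16))
out.sum().backward()
print(type(qat_model[1]).__name__, tuple(out.shape))  # -> QATChannelScaleNorm (2, 8, 14, 14)
```

In recent releases prepare_qat also treats the classes in mapping.values() as modules whose output gets an activation_post_process, so the custom QAT module should not need to create one itself; it is worth checking this against the version you use.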

What should the qmod constructor inputs be? I defined a threshold comparison layer for quantization after training, but ran into a lot of difficulties:

class Threshold(nn.Module):
    def __init__(self, num_channels, qconfig=None):
        super().__init__()
        self.num_channels = num_channels
        self.threshold_plus = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        self.threshold_minus = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        self.threshold_plus_sign = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        self.threshold_minus_sign = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        # self.quant = torch.quantization.QuantStub()
        # self.dequant = torch.quantization.DeQuantStub()
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()
    def forward(self, x):
        # x = self.quant(x)
        # assert x.size(1) == self.num_channels, f"Expected {self.num_channels} channels, but got {x.size(1)}"
        # Apply thresholding to each channel
        output = []
        for i in range(self.num_channels):
            channel = x[:, i, :]
            threshold_p = self.weight_fake_quant(self.threshold_plus[i])
            # threshold_p = self.threshold_plus[i]
            threshold_m = self.threshold_minus[i]
            p_sign = self.threshold_plus_sign[i]
            m_sign = self.threshold_minus_sign[i]
            channel_output = torch.where(channel >= 0, torch.where(channel > threshold_p, p_sign * torch.tensor(1.0),
                           p_sign * torch.tensor(-1.0)),
                            torch.where(channel > threshold_m, m_sign * torch.tensor(1.0),
                             m_sign * torch.tensor(-1.0)))
            output.append(channel_output)
        # Stack the thresholded channels along the channel dimension
        output = torch.stack(output, dim=1)
        # output = self.dequant(output)
        return output
    @classmethod
    def from_float(cls, mod):
        qconfig = mod.qconfig
        qmod = cls(qmod constructor inputs, qconfig = qconfig)
        qmod.weight = mod.weight
        return qmod
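On the constructor-inputs question: they are whatever the QAT class's __init__ needs to rebuild the module, read off the float module (for Threshold, only num_channels). Since Threshold has no weight, from_float would copy the threshold parameters rather than mod.weight. Below is a sketch under assumptions: it splits the module above into a float Threshold and a hypothetical QATThreshold, replaces the per-channel loop with broadcasting (like the original x[:, i, :], it assumes channels on dim 1), and, like the original, fake-quantizes only threshold_plus, applied once to the whole vector.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import default_qat_qconfig

_PARAM_NAMES = ("threshold_plus", "threshold_minus",
                "threshold_plus_sign", "threshold_minus_sign")


def _threshold(x, t_plus, t_minus, p_sign, m_sign):
    # Reshape per-channel vectors to (1, C, 1, ...) so they broadcast over x.
    shape = (1, -1) + (1,) * (x.dim() - 2)
    t_plus, t_minus = t_plus.view(shape), t_minus.view(shape)
    p_sign, m_sign = p_sign.view(shape), m_sign.view(shape)
    pos = torch.where(x > t_plus, p_sign, -p_sign)   # branch used where x >= 0
    neg = torch.where(x > t_minus, m_sign, -m_sign)  # branch used where x < 0
    return torch.where(x >= 0, pos, neg)


class Threshold(nn.Module):
    """Float module: same parameters as above, no fake quantization."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        for name in _PARAM_NAMES:
            self.register_parameter(
                name, nn.Parameter(torch.ones(num_channels), requires_grad=False))

    def forward(self, x):
        return _threshold(x, *(getattr(self, n) for n in _PARAM_NAMES))


class QATThreshold(Threshold):
    """Hypothetical QAT counterpart: threshold_plus goes through a fake-quant."""

    _FLOAT_MODULE = Threshold

    def __init__(self, num_channels, qconfig=None):
        super().__init__(num_channels)
        assert qconfig is not None, "qconfig must be provided for a QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()

    def forward(self, x):
        # Fake-quantize the whole threshold vector once, not element by element.
        t_plus = self.weight_fake_quant(self.threshold_plus)
        return _threshold(x, t_plus, self.threshold_minus,
                          self.threshold_plus_sign, self.threshold_minus_sign)

    @classmethod
    def from_float(cls, mod):
        assert type(mod) is cls._FLOAT_MODULE, "unexpected float module type"
        assert getattr(mod, "qconfig", None) is not None, "float module needs a qconfig"
        # The "constructor inputs" are just what __init__ needs: num_channels.
        qmod = cls(mod.num_channels, qconfig=mod.qconfig)
        # There is no .weight here, so share each threshold parameter instead.
        for name in _PARAM_NAMES:
            setattr(qmod, name, getattr(mod, name))
        return qmod


float_mod = Threshold(4)
float_mod.qconfig = default_qat_qconfig
qat_mod = QATThreshold.from_float(float_mod)
x = torch.randn(2, 4, 5)
out = qat_mod(x)
```

To have prepare_qat perform the swap, a mapping containing Threshold: QATThreshold would be passed in, as suggested earlier in the thread.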