thyeros
February 4, 2021, 10:20pm
1
Hi, I have a specific case and need some help or pointers.
I have designed a specialized normalization layer (with learnable parameters) as an nn.Module subclass and would like to apply QAT to it. However, I couldn’t find any documentation on how to write the corresponding QAT module, for example how to attach weight_fake_quant and activation_post_process.
Any pointers on where to start?
jerryzh168
(Jerry Zhang)
February 5, 2021, 6:35pm
2
We have an example for qat.Conv2d here: pytorch/conv.py at master · pytorch/pytorch · GitHub
Then you can add the mapping in:
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.LinearReLU: nniq.LinearReLU,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
}

# Default map for swapping float module to qat modules
DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Linear: nnqat.Linear,
    nn.modules.linear._LinearWithBias: nnqat.Linear,
    # Intrinsic modules:
    nni.ConvBn1d: nniqat.ConvBn1d,
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
or pass in a mapping that includes the new qat module in pytorch/quantize.py at master · pytorch/pytorch · GitHub
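For readers following along, the "pass in a mapping" option can be sketched roughly as below. This is only an illustration under assumptions: it uses the newer torch.ao.quantization namespace rather than the torch.quantization paths linked in this thread, and MyNorm / QATMyNorm are hypothetical placeholder classes, not anything from PyTorch.

```python
import torch.nn as nn
from torch.ao.quantization.quantization_mappings import (
    get_default_qat_module_mappings,
)


class MyNorm(nn.Module):
    """Hypothetical float layer standing in for the custom module."""


class QATMyNorm(MyNorm):
    """Hypothetical QAT counterpart; a real one also needs from_float()."""


# The getter hands back a copy of the default float -> QAT table,
# so adding an entry here leaves the library defaults untouched.
custom_mapping = get_default_qat_module_mappings()
custom_mapping[MyNorm] = QATMyNorm

# The extended table would then be passed along when preparing the model:
#     prepare_qat(model, mapping=custom_mapping)
```

Keeping the extra entry in a local dict, instead of editing the default table in the source, means the change stays scoped to the one prepare_qat call that needs it.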
Hi Jerry, thanks for sharing that. Yes, I also saw the mapping table, but (I should have been clearer about this) for a custom layer I need to write a corresponding QAT version of that custom layer. Are there any particular requirements for such QAT versions (in the docs or in an example), such as a particular set of functions or attributes? One of them looks like ‘from_float’.
I figured it out. For anyone with the same interest, I hope this template is helpful. Jerry, please comment if anything is missing here.
class qatCustom(nn.Module):
    def __init__(..., qconfig=None):
        super().__init__(....)
        self.qconfig = qconfig
        # skip this if there is no learnable param
        self.weight_fake_quant = qconfig.weight()

    def forward(self, input):
        # implement the forward but with qweight
        qweight = self.weight_fake_quant(self.weight)
        # do compute
        return _forward_imple(input, qweight)

    @classmethod
    def from_float(cls, mod):
        qconfig = mod.qconfig
        qmod = cls(qmod constructor inputs, qconfig=qconfig)
        qmod.weight = mod.weight
        return qmod

torch.quantization.quantization_mappings.register_qat_module_mapping(Custom, qatCustom)
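A filled-in version of this template might look like the sketch below. It is not the exact code from this thread: ChannelScale and QATChannelScale are hypothetical layers invented for illustration, the imports assume the newer torch.ao.quantization namespace, and the mapping is passed to prepare_qat directly instead of going through a registration call. It also uses default_qat_qconfig, whose weight fake-quant is per-tensor, to keep the example simple.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import default_qat_qconfig, prepare_qat
from torch.ao.quantization.quantization_mappings import (
    get_default_qat_module_mappings,
)


class ChannelScale(nn.Module):
    """Hypothetical float layer: scales each channel by a learnable weight."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        self.weight = nn.Parameter(torch.ones(num_channels))

    def forward(self, x):
        return x * self.weight.view(1, -1, 1)


class QATChannelScale(ChannelScale):
    """Hypothetical QAT version: same math, but on a fake-quantized weight."""

    def __init__(self, num_channels, qconfig=None):
        super().__init__(num_channels)
        assert qconfig is not None, "a qconfig is required for QAT"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()

    def forward(self, x):
        qweight = self.weight_fake_quant(self.weight)
        return x * qweight.view(1, -1, 1)

    @classmethod
    def from_float(cls, mod):
        # Rebuild with the float module's own constructor arguments,
        # then share its parameter so training continues from it.
        qmod = cls(mod.num_channels, qconfig=mod.qconfig)
        qmod.weight = mod.weight
        return qmod


model = nn.Sequential(ChannelScale(4))
model.train()  # prepare_qat expects a model in training mode
model.qconfig = default_qat_qconfig

mapping = get_default_qat_module_mappings()
mapping[ChannelScale] = QATChannelScale

qat_model = prepare_qat(model, mapping=mapping)
out = qat_model(torch.randn(2, 4, 8))
```

After prepare_qat, the swapped module should carry both weight_fake_quant (created in its constructor) and an activation_post_process attached by the prepare step, which is why the attribute has to be spelled qconfig on the QAT module.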
What should the qmod constructor inputs be? I defined a threshold comparison function for quantization after training, but ran into a lot of difficulties.
class Threshold(nn.Module):
    def __init__(self, num_channels, qconfig=None):
        super().__init__()
        self.num_channels = num_channels
        self.threshold_plus = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        self.threshold_minus = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        self.threshold_plus_sign = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        self.threshold_minus_sign = nn.Parameter(torch.ones(size=[num_channels]), requires_grad=False)
        # self.quant = torch.quantization.QuantStub()
        # self.dequant = torch.quantization.DeQuantStub()
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()

    def forward(self, x):
        # x = self.quant(x)
        # assert x.size(1) == self.num_channels, f"Expected {self.num_channels} channels, but got {x.size(1)}"
        # Apply thresholding to each channel
        output = []
        for i in range(self.num_channels):
            channel = x[:, i, :]
            threshold_p = self.weight_fake_quant(self.threshold_plus[i])
            # threshold_p = self.threshold_plus[i]
            threshold_m = self.threshold_minus[i]
            p_sign = self.threshold_plus_sign[i]
            m_sign = self.threshold_minus_sign[i]
            channel_output = torch.where(channel >= 0,
                                         torch.where(channel > threshold_p, p_sign * torch.tensor(1.0),
                                                     p_sign * torch.tensor(-1.0)),
                                         torch.where(channel > threshold_m, m_sign * torch.tensor(1.0),
                                                     m_sign * torch.tensor(-1.0)))
            output.append(channel_output)
        # Stack the thresholded channels along the channel dimension
        output = torch.stack(output, dim=1)
        # output = self.dequant(output)
        return output

    @classmethod
    def from_float(cls, mod):
        qconfig = mod.qconfig
        qmod = cls(qmod constructor inputs, qconfig=qconfig)
        qmod.weight = mod.weight
        return qmod
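One possible reading of "qmod constructor inputs" is: whatever arguments the float module itself was built with, which for this layer would be num_channels. The sketch below follows that reading under several assumptions of mine: the layer is split into a float Threshold and a hypothetical QATThreshold subclass, the per-channel loop is replaced by broadcasting, the whole threshold_plus vector is fake-quantized in one call (the posted code does it element by element), and from_float copies the four threshold parameters because this layer has no weight attribute. It uses the torch.ao.quantization namespace and default_qat_qconfig.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import default_qat_qconfig


class Threshold(nn.Module):
    """Float version: per-channel sign thresholding on (N, C, L) inputs."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        for name in ("threshold_plus", "threshold_minus",
                     "threshold_plus_sign", "threshold_minus_sign"):
            param = nn.Parameter(torch.ones(num_channels), requires_grad=False)
            setattr(self, name, param)

    def _apply_thresholds(self, x, threshold_plus):
        # Reshape the per-channel vectors so they broadcast over (N, C, L).
        tp = threshold_plus.view(1, -1, 1)
        tm = self.threshold_minus.view(1, -1, 1)
        ps = self.threshold_plus_sign.view(1, -1, 1)
        ms = self.threshold_minus_sign.view(1, -1, 1)
        pos = torch.where(x > tp, ps, -ps)  # branch for x >= 0
        neg = torch.where(x > tm, ms, -ms)  # branch for x < 0
        return torch.where(x >= 0, pos, neg)

    def forward(self, x):
        return self._apply_thresholds(x, self.threshold_plus)


class QATThreshold(Threshold):
    """Hypothetical QAT version: fake-quantizes threshold_plus before use."""

    def __init__(self, num_channels, qconfig=None):
        super().__init__(num_channels)
        assert qconfig is not None, "a qconfig is required for QAT"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()

    def forward(self, x):
        q_plus = self.weight_fake_quant(self.threshold_plus)
        return self._apply_thresholds(x, q_plus)

    @classmethod
    def from_float(cls, mod):
        # "Constructor inputs" here are the float module's own arguments.
        qmod = cls(mod.num_channels, qconfig=mod.qconfig)
        qmod.threshold_plus = mod.threshold_plus
        qmod.threshold_minus = mod.threshold_minus
        qmod.threshold_plus_sign = mod.threshold_plus_sign
        qmod.threshold_minus_sign = mod.threshold_minus_sign
        return qmod


float_mod = Threshold(num_channels=1)
float_mod.qconfig = default_qat_qconfig
qmod = QATThreshold.from_float(float_mod)
out = qmod(torch.tensor([[[2.0, 0.5, -3.0]]]))
```

Keeping the float and QAT classes separate lets the float model be built without any qconfig, and the swap can then be left to prepare_qat with a mapping entry from Threshold to QATThreshold. Since the output comes from comparisons and torch.where, no gradient reaches the thresholds here, so the fake-quant step only affects the threshold values seen in the forward pass.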