I’ve tried to implement a hard sigmoid activation in a way that is suitable for quantization-aware training:
from torch import nn

class HardSigmoid(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU6()
        # FloatFunctional wrappers so the scalar add/mul can participate in quantization
        self.add = nn.quantized.FloatFunctional()
        self.mul = nn.quantized.FloatFunctional()

    def forward(self, input):
        # hard sigmoid: relu6(input + 3) / 6
        output = self.add.add_scalar(input, 3)
        output = self.act(output)
        output = self.mul.mul_scalar(output, 1 / 6)
        return output
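In plain float mode, i.e. before any quantization preparation, the module matches the reference formula. A minimal sanity check (the random input is just for illustration):

import torch

m = HardSigmoid()
x = torch.randn(8)
# relu6(x + 3) / 6 is the same as clamp(x + 3, 0, 6) / 6
expected = torch.clamp(x + 3.0, 0.0, 6.0) / 6.0
assert torch.allclose(m(x), expected)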
The backward pass and the conversion work fine:
import torch
import torch.quantization as tq
from torch.nn.intrinsic import ConvBn2d

model = nn.Sequential(
    tq.QuantStub(),
    ConvBn2d(
        nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(16),
    ),
    HardSigmoid(),
    tq.DeQuantStub(),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

# prepare for quantization-aware training
model.qconfig = tq.get_default_qat_qconfig('fbgemm')
tq.prepare_qat(model, inplace=True)

# one training step: forward, backward and optimizer step all succeed
x = torch.rand(16, 3, 16, 16)
y = model(x)
y.sum().backward()
optimizer.step()
y = model(x)  # a second forward pass in fake-quant mode also works

# convert to a fully quantized model
model.apply(tq.disable_observer)
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
tq.convert(model.eval(), inplace=True)
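The conversion itself succeeds: printing the converted model shows the ConvBn2d swapped for its quantized counterpart, while (judging from the traceback below) the ReLU6 inside HardSigmoid is still the float module:

print(model)  # the fused conv is quantized; HardSigmoid still contains nn.ReLU6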
But the forward pass of the converted model (model(x)) fails with the following error:
File "/usr/local/lib64/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/usr/local/lib64/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/usr/local/lib64/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "<stdin>", line 10, in forward
File "/usr/local/lib64/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/usr/local/lib64/python3.7/site-packages/torch/nn/modules/activation.py", line 209, in forward
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
File "/usr/local/lib64/python3.7/site-packages/torch/nn/functional.py", line 960, in hardtanh
result = torch._C._nn.hardtanh(input, min_val, max_val)
RuntimeError: Didn't find kernel to dispatch to for operator 'aten::hardtanh'. Tried to look up kernel for dispatch key 'QuantizedCPUTensorId'. Registered dispatch keys are: [CUDATensorId, CPUTensorId, VariableTensorId]
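One workaround I’m considering is to drop back to float around the unsupported op with an extra DeQuantStub/QuantStub pair. This is just a sketch of my own (the class name is mine), not an official recipe:

class HardSigmoidFloatFallback(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU6()
        self.dequant = tq.DeQuantStub()  # leave the quantized domain
        self.quant = tq.QuantStub()      # re-enter it after the activation

    def forward(self, input):
        # relu6(input + 3) / 6, computed entirely in float
        output = self.dequant(input)
        output = self.act(output + 3.0)
        output = self.quant(output / 6.0)
        return output

But this would execute the activation in float at inference time, which defeats the purpose of quantizing the model.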
From the traceback it looks like nn.ReLU6 dispatches to aten::hardtanh, which has no kernel registered for quantized tensors in my PyTorch version. What is the correct way to implement a hard sigmoid activation for quantization-aware training?