Error when writing custom layer for quantization

Hi, I am trying to write a custom layer for a quantized int8 model. I did static eager-mode quantization of a model, and I am able to run it with a custom model that uses layers from torch.nn.intrinsic.quantized. But when I try to use my own layer, I get an error saying:

NotImplementedError: Could not run 'aten::empty.memory_format' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend,

This happens when I try to initialize the weights in my layer. How should I initialize my weights to work with the 'QuantizedCPU' backend? Full code below:

import torch
import torch.quantization
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.intrinsic.quantized as nniq
from collections import OrderedDict

class custom_linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, weight, bias):
        ctx.save_for_backward(X, weight, bias)
        return torch.addmm(bias, X, weight.transpose(0, 1))

    @staticmethod
    def backward(ctx, grad_output):
        X, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(X)
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class QuantLinearReLU(nn.Module):
    def __init__(self, in_features, out_features):
        super(QuantLinearReLU, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=torch.qint8))
        self.bias = torch.nn.Parameter(torch.randn(out_features))
        self.scale = torch.nn.Parameter(torch.randn(1))
        self.zero_point = torch.nn.Parameter(torch.randn(1))

    def forward(self, x):
        x = custom_linear.apply(x, self.weight, self.bias)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(in_features=10, out_features=2)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x


class custom_quant_model(nn.Module):
    def __init__(self):
        super(custom_quant_model, self).__init__()
        self.quant = nnq.Quantize(scale=1.0, zero_point=0, dtype=torch.quint8)
        # Model works fine with the nniq layer but not the custom one. 
        # self.fc = nniq.LinearReLU(in_features=10, out_features=2)
        self.fc = QuantLinearReLU(in_features=10, out_features=2)
        self.dequant = nnq.DeQuantize()

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        x = self.dequant(x)
        return x


device = torch.device('cpu')
model = Net()
testInput = torch.rand(10, 10).cpu()

quant_net = Net()
quant_net.eval()
quant_net.qconfig = torch.quantization.get_default_qconfig('fbgemm')
quant_net_fused = torch.quantization.fuse_modules(quant_net, [['fc', 'relu']])
quant_net_prepared = torch.quantization.prepare(quant_net_fused)

quant_net_prepared(testInput)
quant_int8 = torch.quantization.convert(quant_net_prepared)

quant_custom = custom_quant_model()

# Have to work with a new ordered dict to avoid access issues.
state_dict_copy = quant_int8.state_dict()
new_state_dict = OrderedDict()
new_state_dict['quant.scale'] = state_dict_copy['quant.scale']
new_state_dict['quant.zero_point'] = state_dict_copy['quant.zero_point']
new_state_dict['fc.scale'] = state_dict_copy['fc.scale']
new_state_dict['fc.zero_point'] = state_dict_copy['fc.zero_point']
new_state_dict['fc.weight'], new_state_dict['fc.bias'] = state_dict_copy['fc._packed_params._packed_params']

print("\nNew state dict:\n")
print(new_state_dict)

quant_custom.load_state_dict(new_state_dict)

Thank you!

Update: I fixed this issue by changing

self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=torch.qint8))

to

self.weight = torch.nn.Parameter(torch._empty_affine_quantized([out_features, in_features], scale=1.0, zero_point=0, dtype=torch.qint8))

but now I am getting a new error that says:

RuntimeError: Only Tensors of floating point and complex dtype can require gradients

Currently quantized tensors are only supported during inference; there is no support for autograd. If you are interested in simulating quantization numerics during training, you could fake-quantize your tensors using the torch.quantization.FakeQuantize module or the torch.quantize_per_tensor function. Would that help?
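For example, something along these lines (just a rough sketch; the scale and zero_point below are placeholder values, not from your model):

import torch

x = torch.randn(4, 10)

# Round-trip through a quantized tensor to see the quantization error
# (placeholder scale/zero_point).
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
x_snapped = xq.dequantize()  # float32 again, values snapped to the int8 grid

# FakeQuantize observes the tensor and simulates quantization in float,
# which also gives a straight-through gradient during training.
fq = torch.quantization.FakeQuantize()
x_fq = fq(x)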

Hi @Vasiliy_Kuznetsov, thank you! So I ended up doing something similar, which is basically casting the quantized tensors to regular float32 tensors. I just need to pass them to my own CUDA backend, so just before I do, I use quantize() and int_repr() to convert them to char for the GPU to process. But storing them as float32 tensors allows me to use the nn.functional functions on them, which is a big plus for me.
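Roughly, the hand-off to my backend looks like this (simplified sketch; the scale and zero_point are stand-ins for the values I saved from the converted model):

import torch

# Weights are kept as ordinary float32 tensors so nn.functional ops still work.
w = torch.randn(2, 10)
scale, zero_point = 0.05, 0  # stand-ins for the saved quantization parameters

# Just before calling into the custom CUDA backend: quantize, then grab
# the raw int8 values that the GPU kernel consumes.
w_q = torch.quantize_per_tensor(w, scale=scale, zero_point=zero_point, dtype=torch.qint8)
w_int8 = w_q.int_repr()  # dtype torch.int8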

Also, I'm not sure if I will need the backward pass for what I am doing, but if I do, would the regular backward pass (e.g., the one shown here) work for that? Since all my tensors are float32 anyway and I quantize them (using the scale and zero_point which I have saved as parameters), the backward pass should work the same as for a normal nn.functional.linear function, am I right?
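To make it concrete, this is roughly what I have in mind (just a sketch with placeholder shapes and quantization parameters): fake-quantize the float32 weight with my saved scale/zero_point and let autograd handle the usual linear backward.

import torch
import torch.nn.functional as F

x = torch.randn(4, 10, requires_grad=True)
w = torch.randn(2, 10, requires_grad=True)
b = torch.randn(2, requires_grad=True)
scale, zero_point = 0.05, 0  # placeholders for my saved parameters

# Simulate the int8 rounding on the float32 weight; this op has a
# straight-through gradient, so autograd treats the rest like a normal linear.
w_sim = torch.fake_quantize_per_tensor_affine(w, scale, zero_point, -128, 127)

out = F.linear(x, w_sim, b)
out.sum().backward()  # gradients end up in x.grad, w.grad and b.grad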