How can I incorporate PReLU in a quantized model?

Thanks to dear God its done!
Here is the final solution!:

class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        pos = torch.relu(inputs)
        neg = -self.weight * torch.relu(-inputs)
        inputs = pos + neg
        return inputs

and t his is the quantized version :

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs) 
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))
        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        self.weight = self.dequant(self.weight)
        return inputs

2 Likes