Thank God, it's done!
Here is the final solution:
class PReLU_2(nn.Module):
    """PReLU expressed with ReLU, negation, multiply and add only.

    Computes ``relu(x) - weight * relu(-x)``, which is algebraically the same
    as ``max(0, x) + weight * min(0, x)``.

    Args:
        prelu_object: Object exposing a ``weight`` attribute (presumably an
            ``nn.PReLU`` instance -- confirm against callers). The weight is
            shared by reference, not copied.
    """

    def __init__(self, prelu_object):
        super().__init__()
        # The same weight object is kept under both attribute names.
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        """Apply the PReLU formula to ``inputs`` and return the result."""
        weight = self.weight
        if weight.dim() == 1 and weight.numel() > 1 and inputs.dim() > 1:
            # A per-channel weight has to line up with dim 1 (the channel
            # dim), not with the trailing dim that plain broadcasting picks.
            channel_shape = [1] * inputs.dim()
            channel_shape[1] = -1
            weight = weight.reshape(channel_shape)
        pos = torch.relu(inputs)
        neg = -weight * torch.relu(-inputs)
        return pos + neg
And this is the quantized version:
class PReLU_Quantized(nn.Module):
    """PReLU built from ReLU, multiply and add, routed through quant stubs.

    Computes ``relu(x) + (-weight) * relu(-x)`` using a ``FloatFunctional``
    for the multiply and the add, with the weight passed through a
    ``QuantStub`` and the result passed through a ``DeQuantStub``.

    Args:
        prelu_object: Object exposing a ``weight`` attribute (presumably an
            ``nn.PReLU`` instance -- confirm against callers). The weight is
            shared by reference, not copied.
    """

    def __init__(self, prelu_object):
        super().__init__()
        # The same weight object is kept under both attribute names.
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        """Apply the PReLU formula to ``inputs`` and return the dequant output."""
        # inputs = max(0, inputs) + alpha * min(0, inputs)
        weight = self.weight
        if weight.dim() == 1 and weight.numel() > 1 and inputs.dim() > 1:
            # A per-channel weight has to line up with dim 1 (the channel
            # dim), not with the trailing dim that plain broadcasting picks.
            channel_shape = [1] * inputs.dim()
            channel_shape[1] = -1
            weight = weight.reshape(channel_shape)
        # Keep the stub output in a local: rebinding self.weight would mutate
        # module state on every call, and nn.Module rejects assigning a plain
        # tensor to a registered parameter name.
        weight = self.quant(weight)
        # NOTE(review): negating quantized tensors and reusing one
        # FloatFunctional for both mul and add are kept from the original;
        # verify both against a converted model.
        weight_min_res = self.quantized_op.mul(-weight, torch.relu(-inputs))
        outputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        return self.dequant(outputs)