Shisho_Sama
(A curious guy here!)
July 14, 2020, 2:40am
1
Hello everyone.
This is a follow-up question concerning this. The issue is that in the ResNet model I'm dealing with, I can't replace PReLU with ReLU, as that drastically affects the network's performance.
So my question is: what are my options here? What should I be doing in this case?
Would doing something like this suffice?
```python
class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.weight = prelu_object.weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
        inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
        self.weight = self.dequant(self.weight)
        return inputs
```
And for the replacement:
```python
class model(nn.Module):
    def __init__(self):
        super().__init__()
        # ...
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        # ...
```
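Instead of adding a second attribute next to each `nn.PReLU` by hand, the PReLU modules can also be swapped in place across the whole model. Below is a minimal sketch, assuming the replacement class takes the original `nn.PReLU` as its only constructor argument (as `PReLU_Quantized` does). `swap_prelu` and `SlopeWrapper` are hypothetical names for illustration, not PyTorch APIs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def swap_prelu(module: nn.Module, make_replacement) -> nn.Module:
    """Hypothetical helper: recursively replace every nn.PReLU child, in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.PReLU):
            setattr(module, name, make_replacement(child))
        else:
            swap_prelu(child, make_replacement)
    return module


class SlopeWrapper(nn.Module):
    """Hypothetical stand-in replacement that reuses the trained slope(s)."""

    def __init__(self, prelu: nn.PReLU):
        super().__init__()
        self.weight = prelu.weight  # same Parameter object, no copy

    def forward(self, x):
        return F.prelu(x, self.weight)


# Toy network with one top-level and one nested PReLU.
torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.PReLU(8),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.PReLU(8)),
).eval()
x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    before = net(x)
    swap_prelu(net, SlopeWrapper)
    after = net(x)
```

With the classes from this thread, the call would look like `swap_prelu(model, PReLU_Quantized)`.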
Thanks a lot in advance
Shisho_Sama
(A curious guy here!)
July 14, 2020, 4:17am
2
For some reason, the error between the actual PReLU and my implementation is very large!
Here are sample diffs from different layers:
diff : 1.1562038660049438
diff : 0.02868632599711418
diff : 0.3653906583786011
diff : 1.6100226640701294
diff : 0.8999372720718384
diff : 0.03773299604654312
diff : -0.5090572834014893
diff : 0.1654307246208191
diff : 1.161868691444397
diff : 0.026089997962117195
diff : 0.4205571115016937
diff : 1.5337920188903809
diff : 0.8799554705619812
diff : 0.03827812895178795
diff : -0.40296515822410583
diff : 0.15618863701820374
And the diff is calculated like this in the forward pass:
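```python
def forward(self, x):
    residual = x
    out = self.bn0(x)
    out = self.conv1(out)
    out = self.bn1(out)
    out = self.prelu(out)     # reference nn.PReLU
    # custom version; note it receives the output of self.prelu,
    # not the same input that self.prelu received
    out2 = self.prelu2(out)
    print(f'diff : {(out - out2).mean().item()}')
    out = self.conv2(out)
    # ...
```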
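A comparison like this is only meaningful if both implementations get the same input, and a signed mean can hide errors because positive and negative differences cancel (some of the diffs above are negative). A minimal sketch of a stricter check, assuming both modules are plain float modules; `compare` is a hypothetical helper:

```python
import torch
import torch.nn as nn


def compare(reference: nn.Module, candidate: nn.Module, x: torch.Tensor) -> dict:
    """Hypothetical helper: run both modules on the SAME input and summarize."""
    with torch.no_grad():
        d = candidate(x) - reference(x)
    return {
        "mean": d.mean().item(),          # signed: errors of opposite sign cancel
        "max_abs": d.abs().max().item(),  # worst single element
    }


# Toy check: plain ReLU as a (wrong) candidate against nn.PReLU (default a = 0.25).
x = torch.tensor([-4.0, 4.0])
report = compare(nn.PReLU(), nn.ReLU(), x)
```

Here `report` is `{"mean": 0.5, "max_abs": 1.0}`: only the negative input differs (0 versus 0.25 · -4 = -1).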
This is the plain implementation I used on the ordinary (i.e. not quantized!) model to check whether it produces correct results before moving on to the quantized version:
```python
class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        x = self.weight
        tmin, _ = torch.min(inputs, dim=0)
        tmax, _ = torch.max(inputs, dim=0)
        weight_min_res = torch.mul(x, tmin)
        inputs = torch.add(tmax, weight_min_res)
        inputs = inputs.unsqueeze(0)
        return inputs
```
What am I missing here?
Shisho_Sama
(A curious guy here!)
July 14, 2020, 11:51am
3
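OK, I figured it out! I made a huge mistake at the very beginning. I needed to calculate

$$\mathrm{PReLU}(x) = \max(0, x) + a \cdot \min(0, x)$$

or

$$\mathrm{PReLU}(x) = \begin{cases} x, & \text{if } x \ge 0 \\ a x, & \text{otherwise} \end{cases}$$

elementwise, and not the actual min or max of the whole tensor (which is what `torch.min`/`torch.max` computed), which doesn't make sense!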
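Written in terms of ReLU only (which is what the final solution in this thread ends up using), the formula follows from two elementwise identities:

$$\max(0, x) = \mathrm{ReLU}(x), \qquad \min(0, x) = -\max(0, -x) = -\mathrm{ReLU}(-x)$$

$$\mathrm{PReLU}(x) = \mathrm{ReLU}(x) - a \cdot \mathrm{ReLU}(-x)$$

For example, with $x = -2$ and $a = 0.25$: $\mathrm{ReLU}(-2) - 0.25 \cdot \mathrm{ReLU}(2) = 0 - 0.5 = -0.5 = a x$.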
Now, can anyone do me a favor and tell me how I can vectorize this? I'm kind of lost at the moment!
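One way to vectorize the formula is to replace the reductions with elementwise clamps against 0. A small sketch with made-up values and a single shared slope (as in a default `nn.PReLU()`), contrasting what `torch.min` computes with the elementwise version and checking the latter against `torch.nn.functional.prelu`:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[-2.0, 3.0], [0.5, -1.0]])
a = torch.tensor([0.25])  # a single shared slope, as in nn.PReLU()

# Reductions collapse the tensor, which is what the first attempts computed:
global_min = torch.min(x)         # 0-dim tensor holding -2.0
col_min, _ = torch.min(x, dim=0)  # shape (2,): per-column minimum

# Elementwise (vectorized) version: the result keeps x's shape.
y = torch.clamp(x, min=0) + a * torch.clamp(x, max=0)  # max(0, x) + a * min(0, x)
```

Here `y` is `[[-0.5, 3.0], [0.5, -0.25]]`, the same as `F.prelu(x, a)`.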
Shisho_Sama
(A curious guy here!)
July 14, 2020, 12:06pm
4
Thanks to dear God, it's done!
Here is the final solution:
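```python
class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        # nn.PReLU keeps one slope per channel (dim 1); reshape it so it
        # broadcasts over N, C, H, W instead of over the last dimension
        weight = self.weight.view(1, -1, *([1] * (inputs.dim() - 2)))
        pos = torch.relu(inputs)               # max(0, x)
        neg = -weight * torch.relu(-inputs)    # a * min(0, x) == -a * relu(-x)
        inputs = pos + neg
        return inputs
```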
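When the PReLU was built with `num_parameters=C`, its weight has shape `(C,)` and is applied along dim 1. Multiplying that `(C,)` tensor directly with an NCHW tensor broadcasts it along the last dim (W) instead, which raises an error when W != C and is silently wrong when W == C. A small check against `nn.PReLU`, assuming a 4-D input with made-up sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 4
prelu = nn.PReLU(num_parameters=C)
with torch.no_grad():
    prelu.weight.uniform_(0.1, 0.5)  # a different slope per channel

# W == C on purpose, so the naive broadcast runs instead of raising.
x = torch.randn(2, C, 5, C)

with torch.no_grad():
    ref = prelu(x)
    a = prelu.weight                                # shape (C,)
    naive = torch.relu(x) - a * torch.relu(-x)      # a lines up with W, not C
    a_nchw = a.view(1, -1, 1, 1)                    # shape (1, C, 1, 1): lines up with dim 1
    fixed = torch.relu(x) - a_nchw * torch.relu(-x)

naive_matches = torch.allclose(naive, ref)
fixed_matches = torch.allclose(fixed, ref)
```

`fixed_matches` is `True` and `naive_matches` is `False`; with a single shared slope (`nn.PReLU()`), both would match.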
And this is the quantized version:
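```python
class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        # one FloatFunctional per op, so mul and add each get their own observer
        self.mul_op = nn.quantized.FloatFunctional()
        self.add_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs)
        # per-channel slopes reshaped to broadcast along dim 1 (as nn.PReLU does)
        weight = self.weight.view(1, -1, *([1] * (inputs.dim() - 2)))
        # negate in float, then quantize into a local variable: assigning the
        # quantized tensor back to self.weight (a registered Parameter)
        # raises a TypeError once the model has been converted
        neg_weight = self.quant(-weight)
        weight_min_res = self.mul_op.mul(neg_weight, torch.relu(-inputs))
        inputs = self.add_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        return inputs
```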
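A possible variant that avoids negating the (quantized) input uses the rewrite $\mathrm{PReLU}(x) = (1 - a)\,\max(0, x) + a\,x$: for $x \ge 0$ it gives $(1 - a)x + ax = x$, and for $x < 0$ it gives $a x$. This is a swapped-in algebraic form, not the method from this thread. The sketch below (hypothetical class name `PReLUQuantFriendly`) computes $1 - a$ in float before quantizing and uses one `FloatFunctional` per op. Only the float-mode behavior is checked here. I have not verified that the quantized `mul` kernels broadcast a `(1, C, 1, 1)` tensor against an NCHW tensor in every PyTorch version, so treat the converted path as an assumption:

```python
import torch
import torch.nn as nn
from torch.ao.nn.quantized import FloatFunctional
from torch.ao.quantization import QuantStub


class PReLUQuantFriendly(nn.Module):
    """Hypothetical variant: PReLU(x) = (1 - a) * relu(x) + a * x.

    Assumes channels on dim 1 and slopes taken from a trained nn.PReLU.
    """

    def __init__(self, prelu: nn.PReLU):
        super().__init__()
        self.weight = prelu.weight
        self.quant_a = QuantStub()             # quantizes a
        self.quant_one_minus_a = QuantStub()   # quantizes 1 - a
        # one FloatFunctional per op, so each output gets its own observer
        self.mul_pos = FloatFunctional()
        self.mul_lin = FloatFunctional()
        self.add_out = FloatFunctional()

    def forward(self, x):
        # per-channel slopes reshaped to broadcast along dim 1
        a = self.weight.view(1, -1, *([1] * (x.dim() - 2)))
        qa = self.quant_a(a)
        q1ma = self.quant_one_minus_a(1.0 - a)
        pos = self.mul_pos.mul(q1ma, torch.relu(x))   # (1 - a) * max(0, x)
        lin = self.mul_lin.mul(qa, x)                 # a * x
        return self.add_out.add(pos, lin)


# Float-mode check: before prepare/convert, the stubs and FloatFunctionals
# pass values through unchanged, so this should match nn.PReLU closely.
torch.manual_seed(0)
ref = nn.PReLU(4)
with torch.no_grad():
    ref.weight.uniform_(0.1, 0.5)
mod = PReLUQuantFriendly(ref)
x = torch.randn(2, 4, 5, 6)
with torch.no_grad():
    max_err = (mod(x) - ref(x)).abs().max().item()
```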
viash
(Vishal Shitole)
May 26, 2022, 9:12am
5
Hi, thanks for the solution for quantized PReLU. Any tips on fusing PReLU with conv, i.e. CONV+PRELU and CONV+BN+PRELU?
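As far as I know, PReLU is not among the default eager-mode fusion patterns (those cover combinations such as conv+bn, conv+bn+relu and conv+relu), so a CONV+BN+PRELU block cannot be fused into one module that way. A partial option is to fold BN into the conv with `fuse_modules` and keep PReLU as its own module. A sketch with a hypothetical `ConvBNPReLU` block, in eval mode (post-training quantization):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules


class ConvBNPReLU(nn.Module):
    """Hypothetical block matching the CONV -> BN -> PReLU pattern."""

    def __init__(self, cin: int, cout: int):
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, 3, padding=1)
        self.bn = nn.BatchNorm2d(cout)
        self.prelu = nn.PReLU(cout)

    def forward(self, x):
        return self.prelu(self.bn(self.conv(x)))


torch.manual_seed(0)
block = ConvBNPReLU(3, 8)
with torch.no_grad():
    # non-trivial BN statistics, so folding actually changes the conv weights
    block.bn.running_mean.uniform_(-0.5, 0.5)
    block.bn.running_var.uniform_(0.5, 2.0)
    block.bn.weight.uniform_(0.5, 1.5)
    block.bn.bias.uniform_(-0.2, 0.2)
block.eval()  # conv+bn folding for post-training quantization expects eval mode

# Fold BN into the conv (returns a fused copy); PReLU stays a separate module.
fused = fuse_modules(block, [["conv", "bn"]])
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    max_err = (fused(x) - block(x)).abs().max().item()
```

After fusion, `fused.bn` is an `nn.Identity` and `fused.prelu` is unchanged, so it can still be swapped for a quantization-friendly PReLU module.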