Hello everyone, hope you are having a great time.
I have been playing with quantization recently, and in my last experiments I tried changing the global pooling in my forward pass from this:
```python
def forward(self, x):
    out1 = self.p1(x)
    out2 = self.p2(out1)
    out3 = self.p3(out2)
    out1 = out1.mean(3).mean(2)
    out2 = out2.mean(3).mean(2)
    out3 = out3.mean(3).mean(2)
    out5 = torch.cat((out1, out2, out3), 1)
    output = self.classifier(out5)
    return output
```
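For reference, the `.mean(3).mean(2)` chain is global average pooling over the spatial dimensions of an NCHW tensor. Here is a quick sketch (random input, arbitrary shapes) comparing it with two equivalent ways of writing the same reduction:

```python
import torch
import torch.nn.functional as F

# Arbitrary feature map: batch of 2, 8 channels, 5x7 spatial size (NCHW).
x = torch.randn(2, 8, 5, 7)

# Original approach: average over width (dim 3), then over height (dim 2).
a = x.mean(3).mean(2)

# The same reduction in one call over both spatial dims.
b = x.mean(dim=(2, 3))

# Global average pooling whose 1x1 output size does not depend on the input shape.
c = torch.flatten(F.adaptive_avg_pool2d(x, 1), 1)

print(a.shape, torch.allclose(a, b, atol=1e-6), torch.allclose(a, c, atol=1e-6))
```

Averaging rows of equal length and then averaging those averages gives the overall mean, which is why all three agree.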
to this:
```python
def forward(self, x):
    out1 = self.p1(x)
    out2 = self.p2(out1)
    out3 = self.p3(out2)
    out1 = F.avg_pool2d(out1, kernel_size=out1.size()[2:]).view(out1.size(0), -1)
    out2 = F.avg_pool2d(out2, kernel_size=out2.size()[2:]).view(out2.size(0), -1)
    out3 = F.avg_pool2d(out3, kernel_size=out3.size()[2:]).view(out3.size(0), -1)
    out5 = torch.cat((out1, out2, out3), 1)
    output = self.classifier(out5)
    return output
```
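One way to avoid passing a shape-derived value as `kernel_size` is `F.adaptive_avg_pool2d(x, 1)`, which always pools to a 1x1 output, followed by `torch.flatten(x, 1)`. The sketch below substitutes hypothetical conv stages for the real `p1`/`p2`/`p3` and only checks that `torch.fx.symbolic_trace` (the public name of `torch._fx` in later releases) accepts the pattern. It says nothing about how the quantizer then handles the op.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for the model: p1/p2/p3 are placeholder conv stages."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.p1 = nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.p2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.p3 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.classifier = nn.Linear(8 + 16 + 32, num_classes)

    def forward(self, x):
        out1 = self.p1(x)
        out2 = self.p2(out1)
        out3 = self.p3(out2)
        # output_size=1 is a plain int, so no shape value has to be concrete at trace time
        pooled = [torch.flatten(F.adaptive_avg_pool2d(o, 1), 1) for o in (out1, out2, out3)]
        out5 = torch.cat(pooled, 1)
        return self.classifier(out5)


model = TinyNet().eval()
traced = torch.fx.symbolic_trace(model)  # no Proxy is passed where an int tuple is required
x = torch.randn(2, 3, 32, 32)
print(torch.allclose(traced(x), model(x)))
```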
I did this so I could easily experiment with the qconfig, for example to exclude these operations from quantization with something like this:
```python
import torch
from torch.quantization import get_default_qconfig

qconfig = get_default_qconfig('fbgemm')
qconfig_dict = {'': qconfig,
                "object_type": [
                    (torch.nn.Linear, None),
                    (torch.nn.functional.avg_pool2d, None)
                ]
                }
```
But when I start the quantization process after this change, I get the following error:
```
    model = torch._fx.symbolic_trace(model)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\_fx\symbolic_trace.py", line 168, in symbolic_trace
    return Tracer().trace(root)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\_fx\symbolic_trace.py", line 152, in trace
    self.graph.output(self.create_arg(fn(*args)))
  File "d:\Codes\fac_ver\python\Side_Projects\Facial-Landmark-PyTorch\models\slim.py", line 197, in forward
    out1 = F.avg_pool2d(out1, kernel_size=out1.size()[2:]).view(out1.size(0), -1)
TypeError: avg_pool2d(): argument 'kernel_size' must be tuple of ints, not Proxy
```
It seems that during tracing, `out1.size()[2:]` is a `Proxy(getitem)` instead of a tuple of ints.
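To see where the `Proxy` comes from, here is a small sketch with a hypothetical `SizeDemo` module. During symbolic tracing, `x.size()` is not evaluated but recorded as a graph node, so slicing it produces another `Proxy`. Methods such as `view` accept that Proxy because they are recorded as well, whereas `F.avg_pool2d` needs a concrete tuple of ints for `kernel_size` at trace time, which is what the error above reports.

```python
import operator

import torch
import torch.fx
import torch.nn as nn


class SizeDemo(nn.Module):
    def forward(self, x):
        spatial = x.size()[2:]  # while tracing: a Proxy (size node + getitem node)
        return x.view(x.size(0), -1), spatial


gm = torch.fx.symbolic_trace(SizeDemo())
for node in gm.graph.nodes:
    print(node.op, node.target)

# Running the traced module on a real tensor yields concrete values again.
out, spatial = gm(torch.randn(2, 3, 4, 5))
print(out.shape, spatial)
```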
How can I get around this? I know the error comes from tracing the model with `symbolic_trace`, but if I don't trace it, I can't call `quantize_static_fx`, which requires a symbolically traced model.
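A general escape hatch is `torch.fx.wrap`, which marks a module-level function as a leaf: the tracer records a single `call_function` node for it and never runs its body on Proxies, so the shape logic inside only executes later on real tensors. A minimal sketch, assuming a hypothetical helper `global_avg_pool` that holds the original pooling line. `torch.fx.wrap` has to be called at the top level of the module that defines the function, and the traced graph (and any FX pass, quantization included) then sees only the opaque call, not the `avg_pool2d` inside it.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


def global_avg_pool(x):
    # Hypothetical helper. After torch.fx.wrap below, the tracer records one call to it
    # instead of tracing its body, so x.size() here only ever sees real tensors.
    return F.avg_pool2d(x, kernel_size=x.size()[2:]).view(x.size(0), -1)


# Must be called at module (global) scope; the string is the function's global name.
torch.fx.wrap("global_avg_pool")


class Head(nn.Module):
    def forward(self, x):
        return global_avg_pool(x)


gm = torch.fx.symbolic_trace(Head())
x = torch.randn(2, 4, 5, 7)
print(torch.allclose(gm(x), x.mean(dim=(2, 3)), atol=1e-6))
```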
PS: I know I can just use `torch.mean` instead, but I'd like to know how to deal with this `Proxy` issue in general.
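The module-level counterpart is a custom `torch.fx.Tracer` that overrides `is_leaf_module`, so that a chosen submodule is kept as a single `call_module` node and its `forward` is never traced. The sketch assumes a hypothetical `GlobalAvgPool` module and a toy `Net`. Recent FX graph mode quantization releases also provide a `non_traceable_module_class` option in the prepare custom config for the same purpose; check the documentation for your PyTorch version.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class GlobalAvgPool(nn.Module):
    """Hypothetical module wrapping the shape-dependent pooling from the question."""

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:]).view(x.size(0), -1)


class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # Keep GlobalAvgPool as one call_module node instead of tracing into it.
        if isinstance(m, GlobalAvgPool):
            return True
        return super().is_leaf_module(m, module_qualified_name)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.pool = GlobalAvgPool()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(self.pool(self.conv(x)))


net = Net().eval()
graph = LeafTracer().trace(net)
gm = torch.fx.GraphModule(net, graph)
x = torch.randn(2, 3, 16, 16)
print(torch.allclose(gm(x), net(x)))
```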
Thanks a lot in advance