Good catch.
It's bigger than view, though: pretty much every quant pattern in GeneralTensorShapeOpQuantizeHandler has the same issue if you do anything other than hard-code the non-tensor arguments. To be honest, I'm not sure why these ops need a quant handler at all, since they can handle both normal and quantized tensors; they don't break anything if they are excluded from the flow.
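To illustrate the point that these shape ops already work on quantized tensors, here is a minimal eager-mode sketch. The shapes mirror the repro below; the scale and zero_point values are arbitrary choices for illustration, not anything FX would pick.

```python
import torch

# Quantize a float tensor per-tensor (scale/zero_point chosen arbitrarily).
x = torch.randn(5, 6, 14, 14)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

# Shape ops run directly on the quantized tensor: no observer, no requantization.
flat = torch.flatten(qx, 1)
viewed = flat.view(-1, 5)

# The result is still quantized and keeps the original quantization params.
print(viewed.is_quantized, viewed.dtype, tuple(viewed.shape))
print(viewed.q_scale(), viewed.q_zero_point())
```

Since view/flatten only change tensor metadata, the quantization parameters pass through unchanged, which is why a dedicated observer around them seems unnecessary.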
Repro, where the view size comes from a forward input instead of a hard-coded constant:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.lin = nn.Linear(5, 1)

    def forward(self, x, y):
        x = self.pool(F.relu(self.conv1(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = x.view(-1, y)  # y is a non-tensor forward argument, not a constant
        x = self.lin(x)
        return x

model = Net().eval()
model(torch.randn(5, 3, 32, 32), 5)

qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
# Workaround: exclude view from quantization
# qconfig_dict = {"": qconfig, "object_type": [('view', None)]}

prepared_model = prepare_fx(model, qconfig_dict)
print(prepared_model.code)
prepared_model(torch.randn(5, 3, 32, 32), 5)

final_model = convert_fx(prepared_model)
print(final_model.code)
final_model(torch.randn(5, 3, 32, 32), 5)
@vlc You can solve the issue by specifying None as the qconfig for view (see the commented-out qconfig_dict in the repro) to exclude it.
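If you hit the same thing with other shape ops, the same trick should extend to them. A minimal sketch, assuming the ops in question are view, reshape, and torch.flatten; that list and the helper name qconfig_dict_skipping_shape_ops are hypothetical illustrations, not the full set of ops covered by GeneralTensorShapeOpQuantizeHandler:

```python
import torch

def qconfig_dict_skipping_shape_ops(qconfig, shape_ops):
    """Hypothetical helper: apply `qconfig` globally, but set None for each
    op in `shape_ops` so FX quantization leaves those ops out of the flow."""
    return {"": qconfig, "object_type": [(op, None) for op in shape_ops]}

qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
# Method calls are matched by name (as with 'view' in the repro), functions by object.
qconfig_dict = qconfig_dict_skipping_shape_ops(qconfig, ["view", "reshape", torch.flatten])
```

The resulting dict is passed to prepare_fx exactly like the commented-out one in the repro.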