class Net(nn.Module):
    """Minimal repro model that concatenates three tensors along dim 3.

    It has no parameters; it exists to check how FX graph mode quantization
    handles ``torch.cat``. All inputs must match in every dimension except
    dim 3.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x0, x1, x2):
        """Return ``torch.cat([x0, x1, x2], 3)``.

        Args:
            x0, x1, x2: tensors with at least 4 dimensions whose sizes agree
                everywhere except dim 3.

        Returns:
            A tensor whose dim-3 size is the sum of the inputs' dim-3 sizes;
            all other dimensions are unchanged.
        """
        return torch.cat([x0, x1, x2], 3)
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
# Example inputs; they differ only in dim 3, the dim Net concatenates along.
fp32_input0 = torch.randn(1, 5, 5, 128)
fp32_input1 = torch.randn(1, 5, 5, 64)
fp32_input2 = torch.randn(1, 5, 5, 32)
example_inputs = (fp32_input0, fp32_input1, fp32_input2)

model = Net()
model.eval()
# Float reference output.
y = model(*example_inputs)

qconfig = get_default_qconfig("qnnpack")
qconfig_dict = {"": qconfig}
# NOTE(review): on PyTorch >= 1.13 prepare_fx also requires example inputs,
# i.e. prepare_fx(model, qconfig_dict, example_inputs) — confirm against the
# installed version before changing this call.
model_prepared = prepare_fx(model, qconfig_dict)

# Calibrate: post-training static quantization needs representative data to
# flow through the observers that prepare_fx inserted. Skipping this leaves
# the observers uninitialised, and convert_fx then falls back to default
# qparams (PyTorch warns "must run observer before calling calculate_qparams").
with torch.no_grad():
    model_prepared(*example_inputs)

model_int8 = convert_fx(model_prepared)
# NOTE(review): torch.cat applied to quantized tensors dispatches to the
# quantized kernel at runtime. Depending on the PyTorch version, the traced
# graph may therefore show aten::cat rather than an explicit quantized::cat —
# inspect model_int8.code to check that quantize_per_tensor is applied to the
# inputs before the cat.
torch.jit.save(torch.jit.trace(model_int8, example_inputs), "quantized_model.pt")
# Question: how can I make quantized::cat appear in a network quantized with FX graph mode quantization?