Hello, I am trying to export a Faster R-CNN model from PyTorch to ONNX after applying quantization-aware training (QAT) to its backbone:
```python
import torch
import torchvision
from torch.ao.quantization import get_default_qconfig_mapping, quantize_fx
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
example_input = torch.randn(1, 3, 224, 224)
model.backbone = quantize_fx.fuse_fx(model.backbone)  # Fuse the backbone's modules
model.train()  # The model must be in train mode before it is prepared for QAT
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
# prepare_qat_fx expects example_inputs as a tuple
model.backbone = quantize_fx.prepare_qat_fx(model.backbone, qconfig_mapping, (example_input,))
model.backbone = quantize_fx.convert_fx(model.backbone)
```
When I run the ONNX export, I get this error:
RuntimeError: Type 'Tuple[Tensor, NoneType]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced
I do not get this error when I export the unquantized model. Thank you very much for your help.
I got the same error and figured out the cause. I'm not sure whether your error has the same cause as mine, but this explanation will likely help someone anyway, since it took me hours to figure out.
`RuntimeError: Type 'Tuple[Tensor, NoneType]' cannot be traced ...` will happen if you try to export a quantized model and the model's `forward` method has a second argument with a default value of `None`.

The `torch.onnx.export` function first puts your Tensor example input in a tuple. Then it inspects the signature of the model's `forward` method, and if it finds that the method takes a second argument with a default value of `None`, it automatically adds a second "input" to the tuple with the value `None`.
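To make the argument-filling step concrete, here is a simplified sketch in plain Python. `fill_default_args` is a hypothetical helper that only mimics the behaviour described above; it is not PyTorch's actual code, and a string stands in for the example tensor. Note that torchvision's detection models (Faster R-CNN included) declare `forward(self, images, targets=None)`, so the same thing likely happens in the question's case.

```python
import inspect


def fill_default_args(forward, args):
    """Hypothetical helper: append the default value of every parameter of
    `forward` that was not supplied positionally in `args`."""
    filled = list(args)
    # For a bound method, inspect.signature() already omits `self`.
    params = list(inspect.signature(forward).parameters.values())
    for param in params[len(filled):]:
        if param.default is not inspect.Parameter.empty:
            filled.append(param.default)
    return tuple(filled)


class ModelWithTargets:
    # Mimics a detection model: the second argument defaults to None.
    def forward(self, images, targets=None):
        return images


class ModelWithoutTargets:
    # The fix described in this answer: no optional second argument.
    def forward(self, images):
        return images


print(fill_default_args(ModelWithTargets().forward, ("example_input",)))
# ('example_input', None)
print(fill_default_args(ModelWithoutTargets().forward, ("example_input",)))
# ('example_input',)
```

The first call shows how a lone example input silently grows a trailing `None`; the second shows that removing the optional argument leaves the tuple unchanged.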
Next, it sees that you are exporting a quantized model and calls the function `_pre_trace_quant_model`, which simply calls `torch.jit.trace(model, args)`. At this point `args` is the tuple `(example_input, None)`. The error message you see comes from `torch.jit.trace` complaining that it does not support `None` as an argument value.
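The failure can be reproduced without ONNX or quantization by calling `torch.jit.trace` directly with a `None` in the example-input tuple. This is a minimal sketch; `ToyModel` is a made-up stand-in for a real model whose `forward` has an optional second argument.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, targets=None):  # second argument defaults to None
        return self.linear(x)


model = ToyModel().eval()
example_input = torch.randn(1, 4)

try:
    # This mirrors the (example_input, None) tuple that ends up being traced.
    torch.jit.trace(model, (example_input, None))
    error_message = ""
except RuntimeError as exc:
    error_message = str(exc)

print(error_message)
```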
In my case, I resolved the error by modifying the `forward` method of my model to remove the second argument that had the default value of `None`.
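If you cannot edit the model's `forward`, one possible alternative (not the approach I used, and only a sketch) is to wrap the model in a small module whose `forward` takes just the input tensor. `ExportWrapper` and `ToyModel` are hypothetical names used for illustration.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, targets=None):  # the problematic optional argument
        return self.linear(x)


class ExportWrapper(nn.Module):
    """Hypothetical helper: exposes a forward() with no optional arguments."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Call the wrapped model without the optional `targets` argument.
        return self.model(x)


model = ToyModel().eval()
wrapped = ExportWrapper(model).eval()
example_input = torch.randn(1, 4)

# Tracing succeeds because the example-input tuple contains only a tensor.
traced = torch.jit.trace(wrapped, (example_input,))
print(torch.allclose(traced(example_input), model(example_input)))  # True
```

You would then pass `wrapped` instead of `model` to `torch.onnx.export`, so the signature inspection finds no optional `None` argument to add.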
Hope this helps someone!
Edit: I used PyTorch 2.0.0