torch.jit.script and FloatFunctional error

Hello, I am trying to quantize a 3D MobileNet using the quantization packages that PyTorch provides. I followed the steps recommended in https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html. Everything works until I try to compile the model to TorchScript with torch.jit.script in order to save it. After the "Replacing addition with nn.quantized.FloatFunctional" step, scripting fails with the following error:

Traceback (most recent call last):
  File "main.py", line 155, in <module>
    main()
  File "main.py", line 146, in main
    torch.jit.save(torch.jit.script(fp_model), (opt.model_path / 'quantized' / 'quant_mobilenet3d.pth').as_posix())
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1516, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/_recursive.py", line 318, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/_recursive.py", line 372, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1900, in _construct
    init_fn(script_module)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/_recursive.py", line 353, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/_recursive.py", line 372, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1900, in _construct
    init_fn(script_module)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/_recursive.py", line 353, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/_recursive.py", line 376, in create_script_module_impl
    create_methods_from_stubs(concrete_type, stubs)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/_recursive.py", line 292, in create_methods_from_stubs
    concrete_type._create_methods(defs, rcbs, defaults)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1359, in _recursive_compile_class
    _compile_and_register_class(obj, rcb, _qual_name)
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1363, in _compile_and_register_class
    _jit_script_class_compile(qualified_name, ast, rcb)
RuntimeError:
undefined value super:
  File "/home/ctm/.conda/envs/3dcnn/lib/python3.7/site-packages/torch/nn/quantized/modules/functional_modules.py", line 35
    def __init__(self):
        super(FloatFunctional, self).__init__()
        ~~~~~ <--- HERE
        self.activation_post_process = torch.nn.Identity()
'FloatFunctional.__init__' is being compiled since it was called from 'torch.torch.nn.quantized.modules.functional_modules.FloatFunctional'
  File "/home/ctm/afonso/easyride/acceleration/src/models/mobilenetv2.py", line 69
        if self.use_res_connect:
            if self.quantize:
                return nn.quantized.FloatFunctional().add(x, self.conv(x))
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            else:
                return x + self.conv(x)
'torch.torch.nn.quantized.modules.functional_modules.FloatFunctional' is being compiled since it was called from 'InvertedResidual.forward'
  File "/home/ctm/afonso/easyride/acceleration/src/models/mobilenetv2.py", line 69
        if self.use_res_connect:
            if self.quantize:
                return nn.quantized.FloatFunctional().add(x, self.conv(x))
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            else:
                return x + self.conv(x)

Dropping FloatFunctional and keeping the plain addition lets the model be saved and loaded correctly, but inference then fails because the add operation is not supported by the QuantizedCPU backend. Using the addition from QFunctional gives similar backend errors.

I found the answer in another thread, but cannot find it again. In short, nn.quantized.FloatFunctional() must be instantiated in the module's __init__ method so that it is visible to the TorchScript compiler as a submodule attribute; constructing it inside forward is what triggers the "undefined value super" error.
This will not work:

def forward(self, x):
    return nn.quantized.FloatFunctional().add(x, self.conv(x))

Yet something like this will:

def __init__(self):
    super().__init__()
    self.q_add = nn.quantized.FloatFunctional()
    ...

def forward(self, x):
    return self.q_add.add(x, self.conv(x))
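For completeness, here is a minimal self-contained sketch of the working pattern. AddBlock and its 2D conv are stand-ins I made up for illustration, not the actual MobileNet 3D InvertedResidual block from the traceback; the point is only that FloatFunctional is created once in __init__, so scripting succeeds:

```python
import torch
import torch.nn as nn


class AddBlock(nn.Module):
    # Hypothetical residual-style block illustrating the fix.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)  # stand-in for self.conv
        # Created at init time, so TorchScript sees it as a submodule
        # attribute instead of trying to compile its constructor in forward().
        self.q_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        # Equivalent to x + self.conv(x) in float mode, but routed through
        # FloatFunctional so observers can attach and the quantized add
        # kernel is used after conversion.
        return self.q_add.add(x, self.conv(x))


model = AddBlock()
scripted = torch.jit.script(model)  # no "undefined value super" error
```

Note that in float (pre-conversion) mode, FloatFunctional.add simply computes the sum and passes it through an Identity observer, so eager and scripted outputs match.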