Unable to export a model to ONNX when it contains a custom padding class

When I try to export a model that contains a custom padding class to ONNX, the export fails.
The error is shown below, but I can't find the source of the problem. Can anyone help me figure out what might be causing it? Thanks :)

(I've trimmed parts of the error output due to size limitations.)
Exception has occurred: RuntimeError
ONNX export failed: Couldn’t export Python operator ReflectionPadNd

Defined at:
/mnt/public/sarasaen/Code/FTSuperResDynMRI/utils/padding.py(38): forward
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/module.py(1003): _slow_forward
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/module.py(1015): _call_impl
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/container.py(139): forward
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/module.py(1003): _slow_forward
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/module.py(1015): _call_impl
/mnt/public/sarasaen/Code/FTSuperResDynMRI/models/ReconResNet.py(179): forwardV0
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/module.py(1003): _slow_forward
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/module.py(1015): _call_impl
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/jit/_trace.py(116): wrapper
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/jit/_trace.py(125): forward
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/nn/modules/module.py(1015): _call_impl
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/jit/_trace.py(1158): _get_trace_graph
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/onnx/utils.py(373): _trace_and_get_graph_from_model
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/onnx/utils.py(409): _create_jit_graph
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/onnx/utils.py(445): _model_to_graph
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/onnx/utils.py(676): _export
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/onnx/utils.py(88): export
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/site-packages/torch/onnx/__init__.py(271): export
/mnt/public/sarasaen/Code/FTSuperResDynMRI/rough.py(6): <module>
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/runpy.py(87): _run_code
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/runpy.py(97): _run_module_code
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/runpy.py(265): run_path
/home/sarasaen/.vscode-server-insiders/extensions/ms-python.python-2021.3.680753044/pythonFiles/lib/python/debugpy/…/debugpy/server/cli.py(285): run_file
/home/sarasaen/.vscode-server-insiders/extensions/ms-python.python-2021.3.680753044/pythonFiles/lib/python/debugpy/…/debugpy/server/cli.py(444): main
/home/sarasaen/.vscode-server-insiders/extensions/ms-python.python-2021.3.680753044/pythonFiles/lib/python/debugpy/__main__.py(45): <module>
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/runpy.py(87): _run_code
/mnt/public/sarasaen/bin/anaconda3/envs/torchMRIBeta/lib/python3.8/runpy.py(194): _run_module_as_main

Graph we tried to export:
graph(%LRCurrTP : Float(2, 1, 24, 24, 24, strides=[13824, 13824, 576, 24, 1], requires_grad=0, device=cpu),
%intialConv.1.weight : Float(64, 1, 7, 7, 7, strides=[343, 343, 49, 7, 1], requires_grad=1, device=cpu),
%intialConv.1.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
%downsam.0.conv_block.0.weight : Float(128, 64, 3, 3, 3, strides=[1728, 27, 9, 3, 1], requires_grad=1, device=cpu),
%downsam.0.conv_block.0.bias : Float(128, strides=[1], requires_grad=1, device=cpu),
%downsam.1.conv_block.0.weight : Float(256, 128, 3, 3, 3, strides=[3456, 27, 9, 3, 1], requires_grad=1, device=cpu),
%downsam.1.conv_block.0.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.0.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.0.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.0.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.0.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.1.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.1.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.1.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.1.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.2.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.2.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.2.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.2.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.3.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.3.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.3.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.3.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.4.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.4.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.4.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.4.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.5.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.5.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.5.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.5.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.6.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.6.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.6.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.6.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.7.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.7.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.7.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.7.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.8.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.8.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.8.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.8.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.9.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.9.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.9.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.9.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.10.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.10.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.10.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.10.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.11.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.11.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.11.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.11.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.12.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.12.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.12.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.12.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.13.conv_block.1.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.13.conv_block.1.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%resblocks.13.conv_block.6.weight : Float(256, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%resblocks.13.conv_block.6.bias : Float(256, strides=[1], requires_grad=1, device=cpu),
%upsam.0.conv_block.1.weight : Float(128, 256, 3, 3, 3, strides=[6912, 27, 9, 3, 1], requires_grad=1, device=cpu),
%upsam.0.conv_block.1.bias : Float(128, strides=[1], requires_grad=1, device=cpu),
%upsam.1.conv_block.1.weight : Float(64, 128, 3, 3, 3, strides=[3456, 27, 9, 3, 1], requires_grad=1, device=cpu),
%upsam.1.conv_block.1.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
%finalconv.1.weight : Float(1, 64, 7, 7, 7, strides=[21952, 343, 49, 7, 1], requires_grad=1, device=cpu),
%finalconv.1.bias : Float(1, strides=[1], requires_grad=1, device=cpu),
%348 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%349 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%350 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%351 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%352 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%353 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%354 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%355 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%356 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%357 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%358 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%359 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%360 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%361 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%362 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%363 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%364 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%365 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu),
%366 : Float(1, 1, 1, 1, strides=[1, 1, 1, 1], requires_grad=0, device=cpu)):
%input.3 : Float(2, 1, 30, 30, 30, strides=[27000, 27000, 900, 30, 1], requires_grad=0, device=cpu) = ^ReflectionPadNd((3, 3, 3, 3, 3, 3))(%LRCurrTP) #
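A likely cause, assuming `ReflectionPadNd` in utils/padding.py is a custom `torch.autograd.Function` (its code isn't shown): the tracer records such a Function as a Python operator, which is what the `^ReflectionPadNd(...)` node at the end of the graph is, and the TorchScript-based ONNX exporter can only convert a Python operator if the Function defines a static `symbolic(g, ...)` method. The sketch below is a hypothetical stand-in, not the original class. `ReflectionPad3dFn`, `CustomReflectionPad3d` and `torch_pad_to_onnx_pads` are made-up names; the pad tuple is assumed to follow `F.pad` order (last dimension first); the input is assumed to be 5-D, as in the graph (2, 1, 24, 24, 24); and opset 11 or later is assumed, so that the ONNX `Pad` op takes its pads as a tensor input with `mode="reflect"`.

```python
import torch
import torch.nn as nn


def torch_pad_to_onnx_pads(pad, ndim):
    """Convert an F.pad-style tuple (last dim first, (begin, end) pairs) into
    ONNX Pad order: [dim0_begin, ..., dimN_begin, dim0_end, ..., dimN_end]."""
    begins = [0] * ndim
    ends = [0] * ndim
    for i in range(len(pad) // 2):
        dim = ndim - 1 - i  # pair i applies to the i-th dim counted from the end
        begins[dim] = pad[2 * i]
        ends[dim] = pad[2 * i + 1]
    return begins + ends


class ReflectionPad3dFn(torch.autograd.Function):
    """Hypothetical stand-in for the post's ReflectionPadNd (its code is not shown)."""

    @staticmethod
    def forward(ctx, x, pad):
        # Reflect-pad one trailing dim at a time; reflection is separable, so
        # padding the dims sequentially matches N-D reflect padding.
        for i in range(len(pad) // 2):
            d = x.dim() - 1 - i
            left, right = pad[2 * i], pad[2 * i + 1]
            parts = []
            if left > 0:
                # Mirror elements 1..left; the edge element itself is not repeated
                parts.append(x.narrow(d, 1, left).flip(d))
            parts.append(x)
            if right > 0:
                parts.append(x.narrow(d, x.size(d) - 1 - right, right).flip(d))
            x = torch.cat(parts, dim=d)
        return x

    # backward is omitted in this sketch: export only traces forward,
    # but a layer used for training still needs one.

    @staticmethod
    def symbolic(g, x, pad):
        # Used by the TorchScript-based ONNX exporter instead of failing on the
        # Python operator. Assumes a 5-D (N, C, D, H, W) input and opset >= 11,
        # where Pad takes `pads` as an int64 tensor input rather than an attribute.
        pads = torch_pad_to_onnx_pads(pad, ndim=5)
        pads_t = g.op("Constant", value_t=torch.tensor(pads, dtype=torch.int64))
        return g.op("Pad", x, pads_t, mode_s="reflect")


class CustomReflectionPad3d(nn.Module):
    """Hypothetical module wrapper so the Function can sit inside nn.Sequential."""

    def __init__(self, pad):
        super().__init__()
        self.pad = tuple(pad)

    def forward(self, x):
        return ReflectionPad3dFn.apply(x, self.pad)
```

With a `symbolic` method in place, a call such as `torch.onnx.export(model, x, "model.onnx", opset_version=11)` should emit a standard `Pad` node where the graph above shows `^ReflectionPadNd(...)`. Newer PyTorch releases may route `torch.onnx.export` through a different (dynamo-based) exporter that does not use `symbolic`, and PyTorch 1.10+ also ships `torch.nn.ReflectionPad3d`, which may make the custom class unnecessary; both are worth checking against the installed version.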