.pth to ONNX conversion error

Converting the original GFPGANv1.pth model to ONNX fails with the following error.

Traceback (most recent call last):
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/new_from_github.py", line 45, in <module>
    torch.onnx.export(inference_model,  # model being run
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
    graph = _optimize_graph(
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1940, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File "/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 2519, in _convolution
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape.  [Caused by the value '818 defined in (%818 : Float(*, *, *, *, strides=[8192, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Reshape[allowzero=0](%770, %817), scope: gfpgan.archs.gfpganv1_arch.GFPGANv1::/gfpgan.archs.gfpganv1_arch.StyleGAN2GeneratorSFT::stylegan_decoder/basicsr.archs.stylegan2_arch.StyleConv::style_conv1/basicsr.archs.stylegan2_arch.ModulatedConv2d::modulated_conv # /home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/basicsr/archs/stylegan2_arch.py:282:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Reshape'.]
    (node defined in /home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/basicsr/archs/stylegan2_arch.py(282): forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/basicsr/archs/stylegan2_arch.py(333): forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/gfpgan/archs/gfpganv1_arch.py(102): forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/gfpgan/archs/gfpganv1_arch.py(395): forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/jit/_trace.py(124): wrapper
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/jit/_trace.py(133): forward
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1527): _call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/jit/_trace.py(1285): _get_trace_graph
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py(915): _trace_and_get_graph_from_model
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py(1011): _create_jit_graph
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py(1135): _model_to_graph
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py(1596): _export
/home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/torch/onnx/utils.py(516): export
/home/batman/GFPGAN-Training-Models-To-Onnx/new_from_github.py(45): <module>
)

    Inputs:
        #0: 770 defined in (%770 : Float(*, *, *, *, strides=[8192, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Tile(%769, %766), scope: gfpgan.archs.gfpganv1_arch.GFPGANv1::/gfpgan.archs.gfpganv1_arch.StyleGAN2GeneratorSFT::stylegan_decoder/basicsr.archs.stylegan2_arch.ConstantInput::constant_input # /home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/basicsr/archs/stylegan2_arch.py:399:0
    )  (type 'Tensor')
        #1: 817 defined in (%817 : int[] = prim::ListConstruct(%295, %816, %785, %790), scope: gfpgan.archs.gfpganv1_arch.GFPGANv1::/gfpgan.archs.gfpganv1_arch.StyleGAN2GeneratorSFT::stylegan_decoder/basicsr.archs.stylegan2_arch.StyleConv::style_conv1/basicsr.archs.stylegan2_arch.ModulatedConv2d::modulated_conv
    )  (type 'List[int]')
    Outputs:
        #0: 818 defined in (%818 : Float(*, *, *, *, strides=[8192, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::Reshape[allowzero=0](%770, %817), scope: gfpgan.archs.gfpganv1_arch.GFPGANv1::/gfpgan.archs.gfpganv1_arch.StyleGAN2GeneratorSFT::stylegan_decoder/basicsr.archs.stylegan2_arch.StyleConv::style_conv1/basicsr.archs.stylegan2_arch.ModulatedConv2d::modulated_conv # /home/batman/GFPGAN-Training-Models-To-Onnx/onnxVenv/lib/python3.10/site-packages/basicsr/archs/stylegan2_arch.py:282:0
    )  (type 'Tensor')
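
For context, the failing node sits in ModulatedConv2d (stylegan2_arch.py:282), which appears to use the grouped-convolution trick: the batch is folded into the channel dimension so that each sample gets its own kernels. The following is a minimal sketch of that pattern as I understand it, not the basicsr code; the function name `modulated_conv` and all sizes are made up for illustration. It shows that the kernel tensor's first dimension and the `groups` argument both depend on the batch size, which may be related to the exporter being unable to infer the kernel shape when the batch axis is dynamic.

```python
import torch
import torch.nn.functional as F


def modulated_conv(x, weight, style):
    """Per-sample kernels via a grouped conv (illustrative sketch only)."""
    b, c_in, h, w = x.shape
    _, c_out, _, k, _ = weight.shape               # weight: (1, c_out, c_in, k, k)
    w_mod = weight * style.view(b, 1, c_in, 1, 1)  # (b, c_out, c_in, k, k)
    w_mod = w_mod.view(b * c_out, c_in, k, k)      # first dim depends on the batch size
    x = x.view(1, b * c_in, h, w)                  # fold the batch into the channel dim
    out = F.conv2d(x, w_mod, padding=k // 2, groups=b)
    return out.view(b, c_out, h, w), w_mod.shape


# Hypothetical sizes: c_out=8, c_in=4, k=3
weight = torch.randn(1, 8, 4, 3, 3)
for b in (1, 2):
    out, w_shape = modulated_conv(torch.randn(b, 4, 16, 16), weight, torch.randn(b, 4))
    print(b, tuple(w_shape), tuple(out.shape))
# 1 (8, 4, 3, 3) (1, 8, 16, 16)
# 2 (16, 4, 3, 3) (2, 8, 16, 16)
```

With a fixed batch size the kernel shape is a constant; with a symbolic batch size it is computed at run time from the input shape.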

The script I used to convert GFPGANv1.pth to ONNX:

import cv2
from basicsr.utils import img2tensor
from torchvision.transforms.functional import normalize
import torch

from gfpgan.archs.gfpganv1_arch import GFPGANv1

model_path = "./GFPGANv1.pth"
onnx_path = "./experiments/GFPGAN_v1.onnx"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("-------------------", device)

inference_model = GFPGANv1(
        out_size=512,
        num_style_feat=512,
        channel_multiplier=1,
        decoder_load_path=None,
        fix_decoder=False,
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=1,
        sft_half=True).to(device)

loadnet = torch.load(model_path, map_location=device)  # load tensors onto the selected device (also works on CPU-only machines)
if 'params_ema' in loadnet:
    keyname = 'params_ema'
else:
    keyname = 'params'
inference_model.load_state_dict(loadnet[keyname], strict=False)
inference_model = inference_model.eval()
img_path = './1.png'
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = cv2.resize(input_img, (512, 512))
cropped_face_t = img2tensor(img / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

# Dummy input of shape (batch, channels, height, width), created directly on the target device
mat1 = torch.randn(1, 3, 512, 512, device=device)

#torch_out = torch.jit.trace(inference_model, cropped_face_t)

torch.onnx.export(inference_model,  # model being run
                    mat1,  # model input (or a tuple for multiple inputs)
                    onnx_path,  # where to save the model (can be a file or file-like object)
                    export_params=True,  # store the trained parameter weights inside the model file
                    autograd_inlining=False,
                    opset_version=16,  # the ONNX opset version to target
                    do_constant_folding=True,  # whether to execute constant folding for optimization
                    verbose=True,
                    input_names = ['modelInput'],   # the model's input names 
                    output_names = ['modelOutput'], # the model's output names 
                    dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 
                                'modelOutput' : {0 : 'batch_size'}}
                    )

print("export GFPGANv1 onnx done.")

I have already tried different opset versions and different PyTorch versions. The same error occurs every time, and only PyTorch 2.1.0 prints the detailed error message.
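
Side note on the script: it loads the checkpoint with strict=False, which skips missing and unexpected keys instead of raising, so a mismatch between the constructor arguments and the checkpoint could go unnoticed. A minimal sketch of how the result of `load_state_dict` can be inspected, using a small stand-in model (hypothetical, not GFPGANv1):

```python
import torch
from torch import nn

# Hypothetical stand-in model; the real script would use GFPGANv1.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Checkpoint that lacks the second layer and carries one unknown key.
state = {
    "0.weight": torch.zeros(4, 4),
    "0.bias": torch.zeros(4),
    "extra.weight": torch.zeros(1),
}

result = model.load_state_dict(state, strict=False)
print("missing:", result.missing_keys)        # keys the model expects but the checkpoint lacks
print("unexpected:", result.unexpected_keys)  # keys in the checkpoint that the model does not have
```

If both lists are empty for the real model, the weights were loaded as intended and the export problem lies elsewhere.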

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A2000 12GB
Nvidia driver version: 536.67
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) W-2223 CPU @ 3.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 7
BogoMIPS: 7199.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_vnni flush_l1d arch_capabilities
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 128 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 4 MiB (4 instances)
L3 cache: 8.3 MiB (1 instance)
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] onnx==1.15.0
[pip3] onnxruntime==1.16.2
[pip3] onnxsim==0.4.35
[pip3] torch==2.1.0+cu118
[pip3] torchaudio==2.1.0+cu118
[pip3] torchvision==0.16.0+cu118
[pip3] triton==2.1.0
[conda] Could not collect