RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

I am running QAT on YOLOv7. When I run with BN fusing before PTQ and QAT, it works on GPU. But when I comment out the two lines at https://github.com/NVIDIA-AI-IOT/yolo_deepstream/blob/5af35bab7f6dfca7f1f32d44847b2a91786485f4/yolov7_qat/scripts/qat.py#L79 (i.e. without BN fusing before PTQ and QAT), I get the error below when running on GPU. If I run on CPU, it works.
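
For context, my export path is essentially the same as the repo's; here is a simplified sketch of what scripts/qat_BN.py ends up calling (the variable names and input shape are my approximation, not the exact upstream code):

    # Simplified sketch of my export path (assumption: mirrors quantize.export_onnx in the repo)
    import torch

    device = torch.device("cuda:0")
    model = model.to(device).eval()                      # QAT model lives on the GPU
    dummy = torch.zeros(1, 3, 640, 640, device=device)   # dummy input on the same device

    with torch.no_grad():
        torch.onnx.export(
            model, dummy, "quantization-custom-rules-temp.onnx",
            opset_version=13,
        )

So both the model and the dummy input should already be on cuda:0 when the export runs.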

Traceback (most recent call last):
  File "scripts/qat_BN.py", line 338, in <module>
    cmd_quantize(
  File "scripts/qat_BN.py", line 179, in cmd_quantize
    quantize.apply_custom_rules_to_quantizer(model, export_onnx)
  File "/yolov7_custom_dataset/quantization/quantize.py", line 222, in apply_custom_rules_to_quantizer
    export_onnx(model, "quantization-custom-rules-temp.onnx")
  File "scripts/qat_BN.py", line 138, in export_onnx
    quantize.export_onnx(model, dummy, file, opset_version=13, 
  File "/yolov7_custom_dataset/quantization/quantize.py", line 394, in export_onnx
    torch.onnx.export(model, input, file, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1180, in _model_to_graph
    params_dict = _C._jit_pass_onnx_constant_fold(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

It seems that this error is related to the ONNX conversion. I am using PyTorch 2.0.1. I tried the fix from this thread: [ONNX] Fix onnx constant folding by shingjan · Pull Request #101329 · pytorch/pytorch · GitHub, but I could not find the file torch/csrc/jit/passes/onnx/constant_fold.cpp in my installation to modify it.
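
Since the crash happens inside _C._jit_pass_onnx_constant_fold, the only workaround I can think of is to skip constant folding in the export call, roughly like this (do_constant_folding is a standard torch.onnx.export argument; whether skipping it is safe for the QAT graph is exactly what I am unsure about):

    # Possible workaround (untested assumption): skip the constant-folding pass that crashes
    torch.onnx.export(
        model, dummy, "quantization-custom-rules-temp.onnx",
        opset_version=13,
        do_constant_folding=False,   # avoids _jit_pass_onnx_constant_fold entirely
    )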

What can I do to get rid of this error? Thanks.