Error when using Faster R-CNN with C++ & GPU

I use PyTorch to train a Faster R-CNN model to locate a hand in an image. After converting the model to TorchScript with the tracing method, and even though I am very sure that the model and the input are both placed on the GPU, I still always get the error "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!". Its log is listed below. Can anyone help me solve this issue?

```
terminate called after throwing an instance of 'std::runtime_error'
what(): The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__.py", line 8, in
def forward(self: __torch__.TraceWrapper,
argument_1: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
_0, _1, _2, = (self.model).forward(argument_1, )

return (_0, _1, _2)
File "code/__torch__/torchvision/models/detection/faster_rcnn.py", line 18, in forward
s = ops.prim.NumToTensor(torch.size(img, 1))
s0 = ops.prim.NumToTensor(torch.size(img, 2))
_4, _5, _6, = (_3).forward(argument_1, )
~~~~~~~~~~~ <--- HERE
_7, _8, _9, _10, _11, _12, _13, _14, _15, _16, = (_2).forward(_4, )
_17 = (_1).forward(_7, _8, _9, _10, _11, _12, _13, _14, _15, _4, _5, _6, )
File "code/__torch__/torchvision/models/detection/transform.py", line 12, in forward
_0 = torch.slice(mean, 0, 0, 9223372036854775807, 1)
_1 = torch.unsqueeze(torch.unsqueeze(_0, 1), 2)
_2 = torch.sub(image, _1, alpha=1)
~~~~~~~~~ <--- HERE
_3 = torch.slice(std, 0, 0, 9223372036854775807, 1)
_4 = torch.unsqueeze(torch.unsqueeze(_3, 1), 2)

Traceback of TorchScript, original code (most recent call last):
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torchvision/models/detection/transform.py(124): normalize
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torchvision/models/detection/transform.py(104): forward
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torchvision/models/detection/generalized_rcnn.py(79): forward
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
faster_rcnn_script.py(34): forward
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/nn/modules/module.py(704): _slow_forward
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/nn/modules/module.py(720): _call_impl
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/jit/__init__.py(1109): trace_module
/home/kevin/anaconda3/envs/script/lib/python3.6/site-packages/torch/jit/__init__.py(955): trace
faster_rcnn_script.py(17): do_trace
faster_rcnn_script.py(64): save_jit_model
faster_rcnn_script.py(89): main
faster_rcnn_script.py(93): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
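
For context, the failing `torch.sub` in the serialized code corresponds to the mean subtraction inside `GeneralizedRCNNTransform.normalize` (transform.py line 124 in the original-code traceback). Roughly paraphrased from torchvision 0.7 (not my code, shown only to indicate which tensors are involved):

```
# Rough paraphrase of torchvision 0.7 GeneralizedRCNNTransform.normalize:
def normalize(self, image):
    dtype, device = image.dtype, image.device
    mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
    std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
    # the traced sub of image and mean[:, None, None] is the op that fails above
    return (image - mean[:, None, None]) / std[:, None, None]
```

So either `image` or the traced `mean`/`std` tensor is apparently on the CPU at run time, even though everything was on the GPU during tracing.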

My C++ code is:
```
cv::Mat pred_img = m_predict_img.clone();
#ifdef RGB
pred_img.convertTo(*m_img_float, CV_32FC3, 1.0 / 255.0);
#else
pred_img.convertTo(*m_img_float, CV_32FC1, 1.0 / 255.0);
#endif

#ifdef TRACE
torch::Tensor tensor_image = torch::from_blob(m_img_float->data,
    {1, m_img_float->rows, m_img_float->cols, m_img_float->channels()}, torch::kF32);
tensor_image = tensor_image.permute({0, 3, 1, 2}); // trace
#else
torch::Tensor tensor_image = torch::from_blob(m_img_float->data,
    {m_img_float->rows, m_img_float->cols, m_img_float->channels()}, torch::kF32);
tensor_image = tensor_image.permute({2, 0, 1}); // script
#endif

auto img_var = torch::autograd::make_variable(tensor_image, false);

std::vector<torch::jit::IValue> inputs;
torch::jit::IValue output;
if (check_gpu_available()) {
#ifdef TRACE
    inputs.push_back(img_var.to(torch::kCUDA));
    //inputs.push_back(img_var.to(torch::kCPU));
#else
    inputs.push_back(c10::List<torch::Tensor>({img_var.to(torch::kCUDA)}));
    //inputs.push_back(c10::List<at::Tensor>({img_var.to(torch::kCPU)}));
#endif
    std::cout << "before prediction!" << std::endl;
    output = m_crop_module.forward(inputs);
    std::cout << "after prediction!" << std::endl;
} else {
    inputs.push_back(c10::List<at::Tensor>({img_var.to(torch::kCPU)}));
    output = m_crop_module.forward(inputs);
}
inputs.pop_back();

auto out = output.toTuple()->elements();
```

My environment is:
Ubuntu 20.04 + libtorch 1.6.0 + torchvision 0.7.0

Could you post an executable code snippet using random inputs, which would reproduce this issue, please?
PS: you can post code snippets by wrapping them into three backticks ```, which would make debugging easier.

Thanks for your reply!

The code below is used to convert the Faster R-CNN model to a TorchScript model. You can use the command
`python script.py --model-out ./ --trace trace`
to reproduce the problem. (You just need to copy the code into a file named `script.py`.)

```
import torch
import torchvision
import numpy as np
import argparse
from torchvision import transforms
import cv2

def do_script(model):
    model_script = torch.jit.script(model)
    model_script.eval()
    return model_script

def do_trace(model, in_size=100):
    tmp = torch.rand(1, 3, in_size, in_size).to(next(model.parameters()).device)
    print('Trace data device: %s'%tmp.device)
    print('Trace model device: %s'%next(model.parameters()).device)
    model_trace = torch.jit.trace(model, tmp)
    model_trace.eval()
    return model_trace

def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return (out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"])
    return (out_dict["boxes"], out_dict["scores"], out_dict["labels"])

class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])

def save_jit_model(args, script=True):
    model_funcs = [torchvision.models.detection.fasterrcnn_resnet50_fpn] 
    names = ["faster_rcnn"]

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    for name, model_func in zip(names, model_funcs):
        model = model_func(pretrained=True)
        model.to(device)
        #for nm, param in model.state_dict().items():
        #    print(nm, param.size(), param.device)
        #print('Loading model device: %s'%next(model.parameters()).device)

        in_size = 100
        in_p = torch.rand(1, 3, in_size, in_size)
        if script == False:
            model = TraceWrapper(model)
            inp = in_p.to(device)
        else:
            inp = [in_p[0].to(device)]
        
        model.eval()

        with torch.no_grad():
            out = model(inp)

            if script:
                out = dict_to_tuple(out[0])
                script_module = do_script(model)
                script_out = script_module([inp[0]])[1]
                script_out = dict_to_tuple(script_out[0])
            else:
                script_module = do_trace(model)
                script_out = script_module(inp)

            #assert len(out[0]) > 0 and len(script_out[0]) > 0

            torch._C._jit_pass_inline(script_module.graph)
            torch.jit.save(script_module, args.model_out + '/' + name + ".pt")

def main():
    parser = argparse.ArgumentParser(description='This app is used to convert faster-rcnn model to torchscript')
    parser.add_argument('--model-in', type=str, help='To specify the input model path')
    parser.add_argument('--model-out', type=str, help='To specify the output model path')
    parser.add_argument('--num-classes', type=int, help='To specify the number of classes')
    parser.add_argument('--trace', type=str, default='trace', help='To specify whether to trace or script the model ("trace" or "script")')

    args = parser.parse_args()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    save_jit_model(args, args.trace == 'script')

if __name__ == '__main__':
    main()
```
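
To narrow things down without the C++ side, the saved model can also be loaded back in Python and run on a CUDA tensor. This is only a minimal sketch; the file name `./faster_rcnn.pt` and the input shape are assumptions taken from the script above:

```
import torch

# Minimal sketch: reload the traced model and run it on the GPU.
# "./faster_rcnn.pt" and the 1x3x100x100 input shape are assumptions
# based on the conversion script above.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

module = torch.jit.load('./faster_rcnn.pt', map_location=device)
module.eval()

inp = torch.rand(1, 3, 100, 100, device=device)
with torch.no_grad():
    boxes, scores, labels = module(inp)
print(boxes.device, scores.device, labels.device)
```

If this check already raises the same cuda:0 vs. cpu error, the problem lies in the saved model itself rather than in the C++ inference code.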

Hi ptrblck,
Is there any progress on this issue? I am really keen to get it resolved. Can you give me more advice on it? Thanks!