Code
# Repro: trace `modelx` with torch.jit.trace using a random example input placed on the GPU.
example = torch.rand(1, 3, 224, 224)  # random example input of shape (1, 3, 224, 224)
example = example.to('cuda')  # move the example input to the CUDA device
# Print both devices to confirm the model and the input are on the same device before tracing.
print(f'modelx.device : {next(modelx.parameters()).device}' )  # device of the model's first parameter
print(f'example.device : {example.device}')
# This call raises the RuntimeError (cuda:0 vs cpu) shown in the output trace below,
# even though both printed devices are cuda:0.
model_trace = torch.jit.trace(modelx, example)
Output Trace
modelx.device : cuda:0
example.device : cuda:0
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-31-d2e0f4962a80> in <module>
4 print(f'modelx.device : {next(modelx.parameters()).device}' )
5 print(f'example.device : {example.device}')
----> 6 model_trace = torch.jit.trace(modelx, example)
~\Anaconda3\envs\tf_gpu\lib\site-packages\torch\jit\__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
953 return trace_module(func, {'forward': example_inputs}, None,
954 check_trace, wrap_check_inputs(check_inputs),
--> 955 check_tolerance, strict, _force_outplace, _module_class)
956
957 if (hasattr(func, '__self__') and isinstance(func.__self__, torch.nn.Module) and
~\Anaconda3\envs\tf_gpu\lib\site-packages\torch\jit\__init__.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
1107 func = mod if method_name == "forward" else getattr(mod, method_name)
1108 example_inputs = make_tuple(example_inputs)
-> 1109 module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, strict, _force_outplace)
1110 check_trace_method = module._c._get_method(method_name)
1111
~\Anaconda3\envs\tf_gpu\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
718 input = result
719 if torch._C._get_tracing_state():
--> 720 result = self._slow_forward(*input, **kwargs)
721 else:
722 result = self.forward(*input, **kwargs)
~\Anaconda3\envs\tf_gpu\lib\site-packages\torch\nn\modules\module.py in _slow_forward(self, *input, **kwargs)
702 recording_scopes = False
703 try:
--> 704 result = self.forward(*input, **kwargs)
705 finally:
706 if recording_scopes:
~\Anaconda3\envs\tf_gpu\lib\site-packages\torchvision\models\detection\generalized_rcnn.py in forward(self, images, targets)
96 if isinstance(features, torch.Tensor):
97 features = OrderedDict([('0', features)])
---> 98 proposals, proposal_losses = self.rpn(images, features, targets)
99 detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
100 detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
~\Anaconda3\envs\tf_gpu\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
718 input = result
719 if torch._C._get_tracing_state():
--> 720 result = self._slow_forward(*input, **kwargs)
721 else:
722 result = self.forward(*input, **kwargs)
~\Anaconda3\envs\tf_gpu\lib\site-packages\torch\nn\modules\module.py in _slow_forward(self, *input, **kwargs)
702 recording_scopes = False
703 try:
--> 704 result = self.forward(*input, **kwargs)
705 finally:
706 if recording_scopes:
~\Anaconda3\envs\tf_gpu\lib\site-packages\torchvision\models\detection\rpn.py in forward(self, images, features, targets)
491 proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
492 proposals = proposals.view(num_images, -1, 4)
--> 493 boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
494
495 losses = {}
~\Anaconda3\envs\tf_gpu\lib\site-packages\torchvision\models\detection\rpn.py in filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level)
392
393 # select top_n boxes independently per level before applying nms
--> 394 top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
395
396 image_range = torch.arange(num_images, device=device)
~\Anaconda3\envs\tf_gpu\lib\site-packages\torchvision\models\detection\rpn.py in _get_top_n_idx(self, objectness, num_anchors_per_level)
372 pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
373 _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
--> 374 r.append(top_n_idx + offset)
375 offset += num_anchors
376 return torch.cat(r, dim=1)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Yes, I did call model.eval() right after loading the model.