Hello. I'm running into some quite strange behaviour with torch. Here is my code:
```python
import torch

image = ...  # some numpy image
device = "cuda:0"
model = torch.jit.load("model.pt", map_location=device)
result = model(torch.from_numpy(image)[None].to(device))
```
This code throws an error:
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-42-ea97506efa7e> in <module>
----> 1 result = model(torch.from_numpy(image)[None].to(device))
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
RuntimeError: MALFORMED INPUT: bad dtype in CompareSelect
```
However, if I write:
```python
import torch

image = ...  # some numpy image
model = torch.jit.load("model.pt", map_location="cuda:0")
result = model(torch.from_numpy(image)[None].to("cuda:0"))
```
It works fine.
My PyTorch version is `1.9.0+cu102`.
Any ideas why this is happening?
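A minimal, self-contained harness that might help narrow this down. It loads a scripted model once with the device held in a variable and once with the literal `"cuda:0"`, then calls each one several times. `ToyModel`, `run_repeatedly` and the random placeholder image are hypothetical stand-ins for the real `model.pt` and image. The sketch falls back to CPU when no GPU is available. The repeated calls rest on an assumption not confirmed by the traceback above: the failure may depend on how many times the model has been called, because TorchScript profiles the first run(s) before optimizing the graph.

```python
import os
import tempfile

import numpy as np
import torch


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for model.pt, containing a comparison (x > 0)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.where(x > 0, x, torch.zeros_like(x))


def run_repeatedly(path: str, image: np.ndarray, device: str, n_calls: int = 5) -> list:
    """Load the scripted model onto `device` and call it `n_calls` times.

    Several calls are made on purpose: the error may only appear once
    TorchScript has finished profiling and starts optimizing the graph.
    """
    model = torch.jit.load(path, map_location=device)
    x = torch.from_numpy(image)[None].to(device)  # add a batch dimension
    return [model(x) for _ in range(n_calls)]


# The variable and the literal name the same torch.device.
device = "cuda:0"
same_device = torch.device(device) == torch.device("cuda:0")

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    torch.jit.save(torch.jit.script(ToyModel()), path)
    image = np.random.default_rng(0).standard_normal((3, 8, 8)).astype(np.float32)

    if torch.cuda.is_available():
        via_variable = run_repeatedly(path, image, device)   # device from a variable
        via_literal = run_repeatedly(path, image, "cuda:0")  # device as a literal
    else:
        # No GPU here: run both variants on CPU so the harness itself still runs.
        via_variable = run_repeatedly(path, image, "cpu")
        via_literal = run_repeatedly(path, image, "cpu")

outputs_match = all(torch.equal(a, b) for a, b in zip(via_variable, via_literal))
print("same device:", same_device, "| outputs match:", outputs_match)
```

From Python's side, passing `device` and passing `"cuda:0"` are indistinguishable to `torch.jit.load` and `.to()`, so the spelling alone seems unlikely to be the cause. If the real model fails only after the first call or two, one diagnostic worth trying is to call the private `torch._C._jit_set_texpr_fuser_enabled(False)` before running it. This is an unverified guess for this model, based on `CompareSelect` appearing to be a node in TorchScript's TensorExpr (NNC) fuser IR, and not a confirmed fix.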