Error when using string variable for device

Hello. I’m seeing some quite weird behaviour from torch. Here is my code:

```python
import torch

image = ... # some numpy image
device = "cuda:0"
model = torch.jit.load("model.pt", map_location=device)
result = model(torch.from_numpy(image)[None].to(device))
```

This code throws the following error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-42-ea97506efa7e> in <module>
----> 1 result = model(torch.from_numpy(image)[None].to(device))

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

RuntimeError: MALFORMED INPUT: bad dtype in CompareSelect
```

However, if I write the device string inline:

```python
import torch

image = ... # some numpy image
model = torch.jit.load("model.pt", map_location="cuda:0")
result = model(torch.from_numpy(image)[None].to("cuda:0"))
```

It works fine.

My PyTorch version is `1.9.0+cu102`.
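When reporting an issue like this, a few more environment details can help others reproduce it. The following is a minimal sketch; the helper name `environment_summary` is hypothetical, but the attributes it reads (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`, `torch.cuda.get_device_name()`) exist in PyTorch.

```python
import torch


def environment_summary() -> dict:
    """Hypothetical helper: collect version details useful in a bug report."""
    cuda_ok = torch.cuda.is_available()
    return {
        "torch": torch.__version__,           # e.g. "1.9.0+cu102"
        "cuda_build": torch.version.cuda,     # None for CPU-only builds
        "cuda_available": cuda_ok,
        # Only query the GPU name when a CUDA device is actually present.
        "device_name": torch.cuda.get_device_name(0) if cuda_ok else None,
    }


print(environment_summary())
```

Running `python -m torch.utils.collect_env` from a shell prints a more complete report.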

Any ideas why this is happening?

Could you post the model definition (or, if that’s not possible, a proxy model) so that we can reproduce the issue? Also, am I right in guessing that you used `torch.jit.trace` to store the model?
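A proxy reproduction could look roughly like the minimal sketch below. Everything in it is an assumption: `ProxyModel`, `run_repro`, the input shape and the file name `proxy_model.pt` are hypothetical, and the `torch.where` comparison is included only because the error message mentions `CompareSelect`. It traces the module with `torch.jit.trace`, saves it, reloads it with `map_location` taken from a string variable, and calls it several times in one session. Both snippets in the question pass the same string value to `torch.jit.load` and `.to()`, so calling the model repeatedly is one hypothesis worth checking, since the failing cell was run inside a longer notebook session.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


class ProxyModel(nn.Module):
    """Hypothetical stand-in for the real model, containing a comparison."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.float() / 255.0
        # Comparison + select, loosely mirroring "CompareSelect" in the error.
        return torch.where(x > 0.5, x, 1.0 - x)


def run_repro(num_calls: int = 3) -> torch.Tensor:
    # Fall back to CPU so the sketch runs anywhere; the original error may be GPU-only.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    image = np.random.default_rng(0).integers(0, 256, size=(3, 32, 32), dtype=np.uint8)

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "proxy_model.pt")
        traced = torch.jit.trace(ProxyModel(), torch.from_numpy(image)[None])
        traced.save(path)

        # Load with map_location given as a string variable, as in the question.
        model = torch.jit.load(path, map_location=device)
        result = None
        # Call several times: TorchScript may optimize the graph after the first runs.
        for _ in range(num_calls):
            result = model(torch.from_numpy(image)[None].to(device))
    return result


print(run_repro().shape)
```

If this proxy runs cleanly but the real model still fails, swapping more of the real model's layers into `ProxyModel` would narrow down which operation triggers the error.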