For saving and distributing models we use TorchScript and trace our models with example inputs. We use tracing rather than scripting because we work with Hugging Face Transformers, whose documentation recommends tracing.
We ran into problems at several points, usually in combination with reshaping. A similar bug was apparently reported a couple of years ago, but we can still reproduce it with a minimal EmbeddingBag example: the model is traced on cuda:0, and the resulting trace then fails when loaded and run on CPU.
import tempfile

import torch
from torch import nn

folder = tempfile.mkdtemp()

# Build and trace the model on the GPU.
my_model = nn.EmbeddingBag(5, 10).to('cuda:0')
my_input = torch.tensor([[1, 2], [1, 2]]).to('cuda:0')
my_model(my_input)  # sanity check: the eager forward pass works
traced_model = torch.jit.trace(my_model, my_input)
traced_model.save(f'{folder}/model_jit.pt')

# Load the traced model on the CPU and run it there.
model_jit_loaded = torch.jit.optimize_for_inference(
    torch.jit.load(f'{folder}/model_jit.pt', map_location='cpu').eval())
model_jit_loaded(my_input.to('cpu'))  # raises the error below
This fails with the following error:
Traceback (most recent call last):
File "/home/user/repos/team/repo/_scripts/trace_test.py", line 20, in <module>
model_jit_loaded(my_input.to('cpu'))
File "/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__/torch/nn/modules/sparse.py", line 15, in forward
offsets = torch.arange(0, _1, annotate(number, _2), dtype=4, layout=None, device=torch.device("cuda:0"), pin_memory=False)
input0 = torch.reshape(input, [-1])
_3, _4, _5, _6 = torch.embedding_bag(weight, input0, offsets, False, 1, False, None, False, None)
~~~~~~~~~~~~~~~~~~~ <--- HERE
return _3
Traceback of TorchScript, original code (most recent call last):
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/functional.py(2392): embedding_bag
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/sparse.py(387): forward
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py(1182): _slow_forward
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py(1194): _call_impl
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/jit/_trace.py(976): trace_module
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace
/home/user/repos/team/repo/_scripts/trace_test.py(13): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument offsets in method wrapper___embedding_bag_forward_only)
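The serialized code in the traceback shows torch.arange being called with device=torch.device("cuda:0"), so the device appears to be baked into the trace. A CPU-only variant of our repro (no GPU needed) lets us inspect what the tracer records:

```python
import torch
from torch import nn

# Trace a minimal EmbeddingBag on CPU purely to inspect the recorded graph.
model = nn.EmbeddingBag(5, 10)
example = torch.tensor([[1, 2], [1, 2]])
traced = torch.jit.trace(model, example)

# The TorchScript source of the traced forward; the internal offsets tensor
# is created with a concrete device argument recorded at trace time.
print(traced.code)
```

In the printed code the arange call for the offsets carries a hard-coded device, matching the cuda:0 literal in the serialized code above.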
Is there a fix for this, or are we doing something wrong in general?
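The only workaround we have found so far (an assumption on our side, not a documented fix) is to trace on the device we serve on, so the device recorded in the trace matches the device at load time. For our CPU case:

```python
import tempfile

import torch
from torch import nn

folder = tempfile.mkdtemp()

# Workaround sketch: keep model and example input on CPU while tracing,
# so any device the tracer records is CPU rather than cuda:0.
model = nn.EmbeddingBag(5, 10)
example = torch.tensor([[1, 2], [1, 2]])
traced = torch.jit.trace(model, example)
traced.save(f'{folder}/model_jit_cpu.pt')

# Loading and running on CPU now succeeds.
loaded = torch.jit.optimize_for_inference(
    torch.jit.load(f'{folder}/model_jit_cpu.pt').eval())
out = loaded(example)  # shape (2, 10), one embedding per bag
```

This obviously means tracing once per target device, which is why we are asking whether there is a proper fix.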