JIT Tracing fails with basic EmbeddingBag when changing devices

For saving and distributing models we use JIT and trace our models with example input. We use tracing instead of scripting because we use Huggingface Transformers, which recommends tracing.

We encountered issues at multiple points, usually in combination with reshaping. There apparently was a related bug a couple of years ago, and we could still reproduce it with a minimal EmbeddingBag example.

import tempfile

import torch
from torch import nn

folder = tempfile.mkdtemp()

my_model = nn.EmbeddingBag(5, 10).to('cuda:0')
my_input = torch.tensor([[1, 2], [1, 2]]).to('cuda:0')

my_model(my_input)

traced_model = torch.jit.trace(my_model, my_input)
traced_model.save(f'{folder}/model_jit.pt')

model_jit_loaded = torch.jit.optimize_for_inference(
    torch.jit.load(f'{folder}/model_jit.pt',
                   map_location='cpu').eval())

model_jit_loaded(my_input.to('cpu'))

This fails with the following error:

Traceback (most recent call last):
  File "/home/user/repos/team/repo/_scripts/trace_test.py", line 20, in <module>
    model_jit_loaded(my_input.to('cpu'))
  File "/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torch/nn/modules/sparse.py", line 15, in forward
    offsets = torch.arange(0, _1, annotate(number, _2), dtype=4, layout=None, device=torch.device("cuda:0"), pin_memory=False)
    input0 = torch.reshape(input, [-1])
    _3, _4, _5, _6 = torch.embedding_bag(weight, input0, offsets, False, 1, False, None, False, None)
                     ~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _3

Traceback of TorchScript, original code (most recent call last):
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/functional.py(2392): embedding_bag
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/sparse.py(387): forward
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py(1182): _slow_forward
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py(1194): _call_impl
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/jit/_trace.py(976): trace_module
/home/user/repos/team/repo/.venv/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace
/home/user/repos/team/repo/_scripts/trace_test.py(13): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument offsets in method wrapper___embedding_bag_forward_only)

Is there a fix or are we doing something wrong in general?


Based on your code, it seems you are tracing the model on the GPU first and then trying to move it to the CPU, which seems to cause the issue.
Would it be possible to trace and export the model on the same device?
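A minimal sketch of that suggestion, assuming the model can simply be moved to the CPU after training: move both the model and the example input to the serving device before calling torch.jit.trace, so the device recorded in the graph is the one used at inference time. The training loop is omitted and the file name is a placeholder.

```python
import os
import tempfile

import torch
from torch import nn

# Build (and, in practice, train) the model on whatever device is available.
train_device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = nn.EmbeddingBag(5, 10).to(train_device)
# ... training loop would go here ...

# Move model and example input to the serving device BEFORE tracing,
# so any device constant recorded in the graph refers to the CPU.
model_cpu = model.to("cpu").eval()
example = torch.tensor([[1, 2], [1, 2]])

traced = torch.jit.trace(model_cpu, example)
path = os.path.join(tempfile.mkdtemp(), "model_jit.pt")
traced.save(path)

# Loading on the CPU now matches the device baked in at trace time.
loaded = torch.jit.load(path, map_location="cpu").eval()
out = loaded(example)
```

The downside is that the exported file is then tied to the CPU, which is exactly the limitation discussed below.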


Yes, if we ran it on the same device the code would work. We want to be able to train our networks on GPUs and serve them on CPUs though. We thought that was possible with JIT trace and it’s a requirement for our orchestration.

For tracing, we understood that moving tensors between devices inside the forward step is bad practice, and that is exactly what makes this built-in model fail:

offsets = torch.arange(0, _1, annotate(number, _2), dtype=4, layout=None, 
                       device=torch.device("cuda:0"), pin_memory=False)
input0 = torch.reshape(input, [-1])

This code looks suspiciously like something that would cause this problem if we did it in our own forward step, but it happens inside the built-in EmbeddingBag.
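To check whether a traced module has a device baked in, one option is to scan its `.code` listing for `torch.device("...")` literals like the one in the serialized code above. This is only a heuristic sketch: `find_device_literals` is a hypothetical helper, and it assumes the printed TorchScript code uses the same `torch.device("cuda:0")` form shown in the traceback.

```python
import re

import torch
from torch import nn

# Matches device literals such as torch.device("cuda:0") in TorchScript code.
_DEVICE_LITERAL = re.compile(r'torch\.device\("([^"]+)"\)')


def find_device_literals(code: str) -> set:
    """Return every device string hard-coded in a TorchScript code listing."""
    return set(_DEVICE_LITERAL.findall(code))


# Usage: trace on the CPU and inspect the serialized forward().
traced = torch.jit.trace(nn.EmbeddingBag(5, 10), torch.tensor([[1, 2], [1, 2]]))
print(traced.code)
print(find_device_literals(traced.code))
```

A non-empty result does not prove the module will fail on another device, but it points at the lines worth looking at.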

With Huggingface Transformers we can freely move onto different instances after tracing, and we thought this was the point of having traced models.


Yes, I think you’ve narrowed down the offending line of code, which seems to bake the device used during tracing into the graph. The general recommendation was to use torch.jit.script, as it also captures if conditions etc., but note that TorchScript is also in maintenance mode and won’t get any new features anymore.


Generally yes, but tracing without a baked-in device would be preferable. We also have instances with multiple GPUs for inference, and being limited to ‘cpu’ and ‘cuda:0’ would be a problem there.
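One workaround for the multi-GPU case, assuming it is acceptable to ship one artifact per serving device, is to trace a separate copy of the model for each target device. A minimal sketch; `trace_per_device` and the file naming scheme are hypothetical:

```python
import copy
import os
import tempfile

import torch
from torch import nn


def trace_per_device(model: nn.Module, example: torch.Tensor, devices, folder: str) -> dict:
    """Trace and save one TorchScript artifact per target device.

    Each trace bakes in its own device, so every serving device gets its own file.
    """
    paths = {}
    for dev in devices:
        # deepcopy so that moving to `dev` does not mutate the original model
        m = copy.deepcopy(model).to(dev).eval()
        traced = torch.jit.trace(m, example.to(dev))
        path = os.path.join(folder, f"model_jit_{dev.replace(':', '_')}.pt")
        traced.save(path)
        paths[dev] = path
    return paths


devices = ["cpu"] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
model = nn.EmbeddingBag(5, 10)
example = torch.tensor([[1, 2], [1, 2]])
paths = trace_per_device(model, example, devices, tempfile.mkdtemp())

# Load each artifact onto the device it was traced for.
loaded = {dev: torch.jit.load(p, map_location=dev) for dev, p in paths.items()}
out = loaded["cpu"](example)
```

This trades disk space and orchestration complexity for not having to change how the model is exported.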

but note that TorchScript is also in maintenance mode

Good to know. As I’ve mentioned above, scripting is not really an option for us anyway because we use the Transformers library.

Do I understand correctly that exporting models without baking in the device is then only possible with ONNX?

Just as a quick note on why we would have liked to avoid going back to ONNX: we came from ONNX and moved to jit.trace because we had problems with packed padded sequences in RNNs, which were sorted unexpectedly during export and broke our internal structure. Plus, JIT could handle **kwargs or dictionaries in the forward step, which is a very nice feature.

In any case thank you for the fast and helpful replies!


No, torch.jit.script will also record the .device attribute (and other control-flow) and should work as seen here:

import torch
from torch import nn

my_model = nn.EmbeddingBag(5, 10).to('cuda:0')
my_input = torch.tensor([[1, 2], [1, 2]]).to('cuda:0')

# works
out = my_model(my_input)

# trace
traced_model = torch.jit.trace(my_model, my_input)
# works
out = traced_model(my_input)
# breaks: the traced graph still creates `offsets` on cuda:0
traced_model.cpu()
traced_model(my_input.cpu())
# RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument offsets in method wrapper_CUDA___embedding_bag)

# script
my_model = nn.EmbeddingBag(5, 10).to('cuda:0')
my_input = torch.tensor([[1, 2], [1, 2]]).to('cuda:0')
scripted_model = torch.jit.script(my_model)  # scripting compiles the code, no example input needed
# works
out = scripted_model(my_input)
# works: offsets are created on input.device at runtime
scripted_model.cpu()
out = scripted_model(my_input.cpu())

However, I don’t know if your model requires trace and will fail otherwise.
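Applied to the original save/load workflow, a minimal sketch would script on the training device, save, and load with `map_location='cpu'`. The file name is a placeholder, and `optimize_for_inference` from the original post is left out to keep the example small.

```python
import os
import tempfile

import torch
from torch import nn

train_device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = nn.EmbeddingBag(5, 10).to(train_device)

# Script instead of trace: input.device stays dynamic in the compiled code.
scripted = torch.jit.script(model)
path = os.path.join(tempfile.mkdtemp(), "model_script.pt")
scripted.save(path)

# Load onto the CPU and run with a CPU input.
loaded = torch.jit.load(path, map_location="cpu").eval()
x = torch.tensor([[1, 2], [1, 2]])
out = loaded(x)
```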


(in colab with then4p)

Yes, it requires tracing, sadly in multiple places. The example we gave is just a very condensed version of the part where we found the problem, prepared for this forum discussion. We include multiple external models (which can’t be scripted) and internal operations such as reshape and repeat (which can’t be traced reliably, as the EmbeddingBag case shows).
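One pattern that might help with this mix is described in the "Mixing Tracing and Scripting" part of the TorchScript docs: script only the device-sensitive submodule and trace the rest, so the traced graph calls into the scripted submodule instead of re-tracing it. `Wrapper` and `head` below are hypothetical stand-ins for a real model. The sketch only shows that the combination traces and runs on the CPU; whether it keeps the device dynamic for your real models, and whether the Transformers parts tolerate it, is an assumption to verify.

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical model: a scripted EmbeddingBag inside a traced parent."""

    def __init__(self):
        super().__init__()
        # Script only the device-sensitive part so input.device stays dynamic.
        self.bag = torch.jit.script(nn.EmbeddingBag(5, 10))
        # Everything else (e.g. trace-only external models) stays a plain module.
        self.head = nn.Linear(10, 3)

    def forward(self, x):
        return self.head(self.bag(x))


model = Wrapper().eval()
example = torch.tensor([[1, 2], [1, 2]])
traced = torch.jit.trace(model, example)
out = traced(example)
```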

Did we understand correctly that TorchScript as a whole, including both scripting and tracing, is no longer in development? So in the future ONNX will be the only option for exporting and serializing models?

If that is the case, for the reasons above we will probably have to switch back to ONNX.


Yes, based on this message from a code owner.

I don’t know what the desired export path will be.

EDIT: here torch.export is mentioned for the 2.0 stack, which seems to be the planned approach.
