Hi all, I have a model which contains a BiLSTM that helps generate full-context embeddings for a list of images.
The model works fine with AMP: training, inference, all good. Except now I'd like to deploy it using AWS Elastic Inference, which requires a JIT-traced model.
When I run the export I receive the following:
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py:966: TracerWarning: Output nr 2. of the traced function does not match the corresponding output of the Python function. Detailed error:
With rtol=1e-05 and atol=1e-05, found 1 element(s) (out of 1) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.009498357772827148 (1.0842735767364502 vs. 1.0937719345092773), which occurred at index 0.
_module_class,
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py:966: TracerWarning: Output nr 3. of the traced function does not match the corresponding output of the Python function. Detailed error:
With rtol=1e-05 and atol=1e-05, found 15 element(s) (out of 15) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.026749849319458008 (0.37732791900634766 vs. 0.35057806968688965), which occurred at index (4, 2).
_module_class,
Traceback (most recent call last):
File "/home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/scripts/export-to-traced-model.py", line 57, in <module>
model, (support_x, support_y_onehot, target_x, target_y)
File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py", line 742, in trace
_module_class,
File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py", line 966, in trace_module
_module_class,
File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
return func(*args, **kwargs)
File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py", line 519, in _check_trace
raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
Node:
%hx : Tensor = prim::Constant[value=<Tensor>](), scope: __module.lstm # /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py:50:0
Source Location:
/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(50): repackage_hidden
/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(52): <genexpr>
/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(52): repackage_hidden
/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(56): forward
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(726): _slow_forward
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(742): _call_impl
/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/matching.py(70): forward
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(726): _slow_forward
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(742): _call_impl
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py(940): trace_module
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py(742): trace
/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/scripts/export-to-traced-model.py(57): <module>
/home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py(85): _run_code
/home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py(193): _run_module_as_main
Comparison exception: With rtol=0.0001 and atol=1e-05, found 90 element(s) (out of 96) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 11.000000476837158 (-0.947547435760498 vs. -11.947547912597656), which occurred at index (0, 0, 8).
The error says the tracer encountered untraceable code, and it points at the repackage_hidden method of my LSTM. Here is my LSTM module:
from __future__ import annotations
import torch
import torch.nn as nn
from torch.autograd import Variable
class BidirectionalLSTM(nn.Module):
    def __init__(self, layer_size, vector_dim, device):
        """
        Initialize a multi-layer bidirectional LSTM.
        :param layer_size: a list of each layer's size
        :param vector_dim: dimensionality of the input vectors
        :param device: device on which to allocate the hidden state
        """
        super().__init__()
        self.batch_size = 1
        self.hidden_size = layer_size[0]
        self.vector_dim = vector_dim
        self.num_layer = len(layer_size)
        self.lstm = nn.LSTM(
            input_size=self.vector_dim,
            num_layers=self.num_layer,
            hidden_size=self.hidden_size,
            bidirectional=True,
        )
        # persistent hidden state, created once and reused across forward calls
        self.hidden = (
            Variable(
                torch.zeros(
                    self.lstm.num_layers * 2,
                    self.batch_size,
                    self.lstm.hidden_size,
                ),
                requires_grad=False,
            ).to(device),
            Variable(
                torch.zeros(
                    self.lstm.num_layers * 2,
                    self.batch_size,
                    self.lstm.hidden_size,
                ),
                requires_grad=False,
            ).to(device),
        )
    def repackage_hidden(self, h):
        """Wraps hidden states in new Variables,
        to detach them from their history."""
        if type(h) == torch.Tensor:
            return Variable(h.data)
        else:
            return tuple(self.repackage_hidden(v) for v in h)
    def forward(self, inputs):
        inputs = inputs.float()
        # detach the persistent hidden state so the graph from the
        # previous call is not kept alive
        self.hidden = self.repackage_hidden(self.hidden)
        output, self.hidden = self.lstm(inputs, self.hidden)
        return output
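If I read the error right, the node %hx is my persistent self.hidden being baked into the traced graph as a constant. Here is a tiny hypothetical repro of what I think is happening (not my real model, just my understanding of how the tracer treats a plain tensor attribute that mutates between calls):

import torch
import torch.nn as nn

class Stateful(nn.Module):
    def __init__(self):
        super().__init__()
        # a plain tensor attribute: not an nn.Parameter, not a registered buffer
        self.state = torch.zeros(1)

    def forward(self, x):
        # the tracer records self.state as a prim::Constant in the graph;
        # mutating it between calls makes the traced and eager runs disagree
        self.state = self.state + x.sum()
        return x + self.state

m = Stateful()
# with the default check_trace=True this should complain that outputs
# differ across invocations, much like my error above
traced = torch.jit.trace(m, torch.ones(3))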
And here is my export code:
with torch.jit.optimized_execution(True), torch.no_grad():
    traced_model = torch.jit.trace(
        model, (support_x, support_y_onehot, target_x, target_y)
    )
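(For completeness: once tracing succeeds, the plan is to save the artifact for Elastic Inference; the filename below is just a placeholder.)

# save the traced module so it can be loaded on the inference side
traced_model.save("matching_network_traced.pt")  # placeholder filename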
I assume the if/else control flow in repackage_hidden might be part of the problem (my understanding is that tracing records a single execution path and doesn't capture Python control flow, whereas torch.jit.script does)? How can I bypass or solve this issue?
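For what it's worth, this is the stateless direction I'm considering as a workaround: drop the persistent self.hidden (and repackage_hidden along with it) and build fresh zero states inside forward, sized from the input. An untested sketch:

import torch
import torch.nn as nn

class BidirectionalLSTM(nn.Module):
    def __init__(self, layer_size, vector_dim):
        super().__init__()
        self.hidden_size = layer_size[0]
        self.num_layer = len(layer_size)
        self.lstm = nn.LSTM(
            input_size=vector_dim,
            num_layers=self.num_layer,
            hidden_size=self.hidden_size,
            bidirectional=True,
        )

    def forward(self, inputs):
        inputs = inputs.float()
        # fresh zero state on every call: the tracer records torch.zeros as
        # an op with no cross-call mutation, so nothing gets baked in as a
        # Constant; inputs is (seq_len, batch, vector_dim) for this LSTM
        batch_size = inputs.size(1)
        h0 = torch.zeros(
            self.num_layer * 2, batch_size, self.hidden_size,
            device=inputs.device, dtype=inputs.dtype,
        )
        c0 = torch.zeros_like(h0)
        output, _ = self.lstm(inputs, (h0, c0))
        return output

Would that be the right way to make the module traceable, or is there a way to keep the persistent hidden state across calls?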