LSTM repackage_hidden problem with JIT trace?

Hi all, I have a model that contains a BiLSTM, which helps generate full-context embeddings for a list of images.

The model works fine with AMP for both training and inference. Now, however, I'd like to deploy it using AWS Elastic Inference, which requires a JIT-traced model.

When I run the export I receive the following:

```
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py:966: TracerWarning: Output nr 2. of the traced function does not match the corresponding output of the Python function. Detailed error:
With rtol=1e-05 and atol=1e-05, found 1 element(s) (out of 1) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.009498357772827148 (1.0842735767364502 vs. 1.0937719345092773), which occurred at index 0.
  _module_class,
/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py:966: TracerWarning: Output nr 3. of the traced function does not match the corresponding output of the Python function. Detailed error:
With rtol=1e-05 and atol=1e-05, found 15 element(s) (out of 15) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.026749849319458008 (0.37732791900634766 vs. 0.35057806968688965), which occurred at index (4, 2).
  _module_class,
Traceback (most recent call last):
  File "/home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/davide/Desktop/coefficient/native_oneshot/native_oneshot/scripts/export-to-traced-model.py", line 57, in <module>
    model, (support_x, support_y_onehot, target_x, target_y)
  File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py", line 742, in trace
    _module_class,
  File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py", line 966, in trace_module
    _module_class,
  File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py", line 519, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Tensor-valued Constant nodes differed in value across invocations. This often indicates that the tracer has encountered untraceable code.
        Node:
                %hx : Tensor = prim::Constant[value=<Tensor>](), scope: __module.lstm # /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py:50:0
        Source Location:
                /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(50): repackage_hidden
                /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(52): <genexpr>
                /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(52): repackage_hidden
                /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/lstm.py(56): forward
                /home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(726): _slow_forward
                /home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(742): _call_impl
                /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/matching_network/model/matching.py(70): forward
                /home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(726): _slow_forward
                /home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/nn/modules/module.py(742): _call_impl
                /home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py(940): trace_module
                /home/davide/.virtualenvs/native/lib/python3.7/site-packages/torch/jit/_trace.py(742): trace
                /home/davide/Desktop/coefficient/native_oneshot/native_oneshot/scripts/export-to-traced-model.py(57): <module>
                /home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py(85): _run_code
                /home/davide/.pyenv/versions/3.7.8/lib/python3.7/runpy.py(193): _run_module_as_main
        Comparison exception:   With rtol=0.0001 and atol=1e-05, found 90 element(s) (out of 96) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 11.000000476837158 (-0.947547435760498 vs. -11.947547912597656), which occurred at index (0, 0, 8).
```

The error says that the tracer encountered untraceable code, and it points at the `repackage_hidden` method of my LSTM. Here is my LSTM module:

```python
from __future__ import annotations

import torch
import torch.nn as nn
from torch.autograd import Variable


class BidirectionalLSTM(nn.Module):
    def __init__(self, layer_size, vector_dim, device):
        """
        Initialize a multi-layer bidirectional LSTM.
        :param layer_size: a list with the size of each layer
        :param vector_dim: the dimensionality of the input vectors
        :param device: the device the hidden state is created on
        """
        super().__init__()
        self.batch_size = 1
        self.hidden_size = layer_size[0]
        self.vector_dim = vector_dim
        self.num_layer = len(layer_size)
        self.lstm = nn.LSTM(
            input_size=self.vector_dim,
            num_layers=self.num_layer,
            hidden_size=self.hidden_size,
            bidirectional=True,
        )
        self.hidden = (
            Variable(
                torch.zeros(
                    self.lstm.num_layers * 2,
                    self.batch_size,
                    self.lstm.hidden_size,
                ),
                requires_grad=False,
            ).to(device),
            Variable(
                torch.zeros(
                    self.lstm.num_layers * 2,
                    self.batch_size,
                    self.lstm.hidden_size,
                ),
                requires_grad=False,
            ).to(device),
        )

    def repackage_hidden(self, h):
        """Wraps hidden states in new Variables,
        to detach them from their history."""
        if type(h) == torch.Tensor:
            return Variable(h.data)
        else:
            return tuple(self.repackage_hidden(v) for v in h)

    def forward(self, inputs):
        inputs = inputs.float()
        self.hidden = self.repackage_hidden(self.hidden)
        output, self.hidden = self.lstm(inputs, self.hidden)
        return output
```
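As far as I can tell, the `if`/`else` is not what the tracer trips over: `self.hidden` is a plain tensor attribute, neither a traced input nor a parameter, so `torch.jit.trace` records its current value as a constant (the `%hx : Tensor = prim::Constant[value=<Tensor>]()` node in the log). Because `forward()` overwrites `self.hidden` on every call, the sanity-check pass sees a different constant, which matches the "Tensor-valued Constant nodes differed in value across invocations" message. The minimal sketch below uses a hypothetical, cut-down `StatefulLSTM` (not the module above) and `check_trace=False` to skip that check, just to show the consequence: eager calls keep evolving the state, while the traced module always restarts from the value frozen at trace time.

```python
import torch
import torch.nn as nn


class StatefulLSTM(nn.Module):
    """Hypothetical minimal stand-in for BidirectionalLSTM: the hidden state
    lives on the module as a plain attribute and is overwritten in forward()."""

    def __init__(self, input_size=4, hidden_size=3):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
        zeros = torch.zeros(2, 1, hidden_size)  # (num_layers * 2, batch, hidden)
        self.hidden = (zeros, zeros.clone())

    def forward(self, inputs):
        # detach() roughly plays the role of Variable(h.data) in repackage_hidden.
        self.hidden = tuple(h.detach() for h in self.hidden)
        output, self.hidden = self.lstm(inputs, self.hidden)
        return output


torch.manual_seed(0)
model = StatefulLSTM()
x = torch.randn(5, 1, 4)  # (seq_len, batch, input_size)

with torch.no_grad():
    # check_trace=False skips the sanity check that raised TracingCheckError.
    traced = torch.jit.trace(model, (x,), check_trace=False)
    eager_1, eager_2 = model(x), model(x)
    traced_1, traced_2 = traced(x), traced(x)

# Eager mode: the second call starts from the state left by the first call.
print(torch.equal(eager_1, eager_2))    # False
# Traced mode: the trace-time hidden state was baked into the graph as a constant.
print(torch.equal(traced_1, traced_2))  # True
```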

And here is my export code:

```python
# Separate the context managers with a comma so that both are entered.
with torch.jit.optimized_execution(True), torch.no_grad():
    traced_model = torch.jit.trace(
        model, (support_x, support_y_onehot, target_x, target_y)
    )
```
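A side note on the export snippet: in a `with` statement, `torch.jit.optimized_execution(True) and torch.no_grad()` is first evaluated as an ordinary boolean expression. A context manager object is truthy, so the expression evaluates to the second manager alone and only `torch.no_grad()` is entered; separating the managers with a comma enters both. The sketch below shows this with a hypothetical `tracked` context manager that only records which managers were actually entered.

```python
from contextlib import contextmanager

entered = []


@contextmanager
def tracked(name):
    # Hypothetical stand-in for torch.jit.optimized_execution / torch.no_grad:
    # it only records that its __enter__ actually ran.
    entered.append(name)
    yield


# `A and B`: A is truthy, so the expression evaluates to B and only B is entered.
with tracked("optimized_execution") and tracked("no_grad"):
    pass
with_and = list(entered)
print(with_and)  # ['no_grad']

entered.clear()

# A comma enters both context managers, left to right.
with tracked("optimized_execution"), tracked("no_grad"):
    pass
with_comma = list(entered)
print(with_comma)  # ['optimized_execution', 'no_grad']
```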

I assume the `if`/`else` control flow in `repackage_hidden` might be part of the problem? How can I bypass or solve this issue?
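A possible workaround, sketched under one explicit assumption: that the hidden state does not need to carry over from one `forward()` call to the next. In that case nothing needs to be stored on the module at all; when `nn.LSTM` is called without `(h_0, c_0)`, it starts from zero states on every call, so the tracer has no tensor attribute to freeze into a constant. The branch on `type(h)` in `repackage_hidden` is evaluated once in Python at trace time, so it is probably not the culprit by itself. The class name `TraceableBiLSTM` is hypothetical; it also drops `Variable`, which has been deprecated since PyTorch 0.4 merged Variables into Tensors, and the `device` argument, since the zero states follow the input's device. If the state really must persist across calls, the alternative would be to pass it into `forward()` and return it explicitly, so that it becomes a traced input and output instead of a module attribute.

```python
import torch
import torch.nn as nn


class TraceableBiLSTM(nn.Module):
    """Hypothetical trace-friendly variant of BidirectionalLSTM.

    Assumption: the hidden state does NOT need to carry over between forward()
    calls, so no state is stored on the module.
    """

    def __init__(self, layer_size, vector_dim):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=vector_dim,
            hidden_size=layer_size[0],
            num_layers=len(layer_size),
            bidirectional=True,
        )

    def forward(self, inputs):
        inputs = inputs.float()
        # With no (h_0, c_0) argument, nn.LSTM starts from zero states created
        # with the input's dtype and device on every call.
        output, _ = self.lstm(inputs)
        return output


torch.manual_seed(0)
model = TraceableBiLSTM(layer_size=[16, 16], vector_dim=8).eval()
x = torch.randn(5, 1, 8)  # (seq_len, batch, vector_dim)

with torch.no_grad():
    traced = torch.jit.trace(model, (x,))  # default check_trace=True
    eager_out = model(x)
    traced_out = traced(x)

print(eager_out.shape)                                   # torch.Size([5, 1, 32])
print(torch.allclose(eager_out, traced_out, atol=1e-6))  # True
```

Because both classes keep the LSTM under the attribute name `lstm` and the old `self.hidden` tuple was never part of the `state_dict`, an existing checkpoint should load into this variant unchanged, though that is worth confirming against the actual saved weights.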