TorchScript Inference fails with torch.jit.trace for seq2seq PyTorch model

Issue: A seq2seq PyTorch model that uses a Bi-LSTM and an LSTM fails at inference after being traced with torch.jit.trace.

Model architecture: module.py (master branch) of LeePleased/StackPropagation-SLU on GitHub

Model input: a tuple of input tensors (Tensor1, Tensor2), where
Tensor1 → the words of the input sentence, mapped to their IDs from the dictionary
Tensor2 → the length of the original input text

# Model instance
import torch
from utils.module import ModelManager

path = 'save/model/model.pt'
kwargs, state = torch.load(path, map_location=torch.device('cpu'))
model = ModelManager(**kwargs)
model.load_state_dict(state)
model.eval()

# Model input sample
text = "bring some clay on your return please"
length = 30
X, L = tokenize(text)  # tokenize() is the project's own helper (not shown): word IDs and text length
X = torch.tensor(X, dtype=torch.long).unsqueeze(0)
L = torch.tensor(L, dtype=torch.long).unsqueeze(0)
inputs = (X, L)

# TorchScript trace
traced_script_module = torch.jit.trace(model, inputs)
torch.jit.save(traced_script_module, 'jitTest.pt')
res = traced_script_module(X, L)  # fails at inference time when L (original text length) differs from the traced sample
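To show the mechanism, here is a minimal, hypothetical reproducer (`ToyDecoder`, the hidden size of 4 and the random inputs are illustrative assumptions, not the StackPropagation-SLU code). Any Python loop whose trip count comes from a tensor value, here via `int(length[0])`, is recorded by `torch.jit.trace` unrolled for the example length only, and the tracer emits a `torch.jit.TracerWarning` at that conversion. The sizes in the errors above (7 vs. 9, and index 6 with size 6) line up with the word counts of the traced sentence (7 words) and the two test sentences (9 and 6 words), which is consistent with a length being baked into the trace in the same way.

```python
import warnings

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a step-by-step decoder (not the StackPropagation-SLU code)."""

    def __init__(self, hidden: int = 4):
        super().__init__()
        self.cell = nn.LSTMCell(hidden, hidden)

    def forward(self, x, length):
        # x: [1, seq_len, hidden]; length: [1] holding the true sequence length
        h = x.new_zeros([1, self.cell.hidden_size])
        c = x.new_zeros([1, self.cell.hidden_size])
        outputs = []
        # int(...) turns a tensor into a Python number, so the tracer can only
        # record this loop unrolled for the example length.
        for t in range(int(length[0])):
            h, c = self.cell(x[:, t, :], (h, c))
            outputs.append(h)
        return torch.stack(outputs, dim=1)


torch.manual_seed(0)
model = ToyDecoder().eval()

# Trace with a 7-step example, recording any warnings the tracer emits.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(model, (torch.randn(1, 7, 4), torch.tensor([7])))
for w in caught:
    if issubclass(w.category, torch.jit.TracerWarning):
        print("TracerWarning:", str(w.message).splitlines()[0])

x9, l9 = torch.randn(1, 9, 4), torch.tensor([9])
print(tuple(model(x9, l9).shape))   # (1, 9, 4): eager mode runs 9 steps
print(tuple(traced(x9, l9).shape))  # (1, 7, 4): the trace still runs 7 steps

try:
    # The unrolled trace still reads position 6, which a 6-step input lacks.
    traced(torch.randn(1, 6, 4), torch.tensor([6]))
except (RuntimeError, IndexError) as err:
    print("shorter input failed:", type(err).__name__)
```

In the real model, the same kind of `TracerWarning` printed while tracing may point at the lines where `L` is turned into Python values.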

Error stack traces:

  1. When the input is longer than the sample input used during tracing:
    Input: text = "bring some clay on your return to home please"
    Error:
    RuntimeError: The following operation failed in the TorchScript interpreter
    Traceback of TorchScript (most recent call last):
    module.py(305): forward
    Anaconda\lib\site-packages\torch\nn\modules\module.py(1118): _slow_forward
    Anaconda\lib\site-packages\torch\nn\modules\module.py(1130): _call_impl
    module.py(113): forward
    RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 9 but got size 7 for tensor number 1 in the list.

  2. When the input is shorter than the sample input used during tracing:
    Input: text = "bring some clay on your return"
    Error:
    RuntimeError: The following operation failed in the TorchScript interpreter.
    Traceback of TorchScript (most recent call last):
    module.py(347): forward
    Anaconda\lib\site-packages\torch\nn\modules\module.py(1118): _slow_forward
    Anaconda\lib\site-packages\torch\nn\modules\module.py(1130): _call_impl
    module.py(102): forward
    RuntimeError: index 6 is out of bounds for dimension 0 with size 6
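A hedged way to surface this at trace time rather than at inference time is the `check_inputs` argument of `torch.jit.trace`: the tracer re-traces with each extra example and raises `torch.jit.TracingCheckError` when the recorded graphs differ. The sketch below reuses the hypothetical `ToyDecoder` idea (an illustrative assumption, not the real decoder); with the real model, passing a second `(X, L)` pair built from a sentence of a different length would be the analogous check.

```python
import warnings

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical decoder whose loop count comes from a tensor value."""

    def __init__(self, hidden: int = 4):
        super().__init__()
        self.cell = nn.LSTMCell(hidden, hidden)

    def forward(self, x, length):
        h = x.new_zeros([1, self.cell.hidden_size])
        c = x.new_zeros([1, self.cell.hidden_size])
        outputs = []
        for t in range(int(length[0])):  # unrolled by the tracer
            h, c = self.cell(x[:, t, :], (h, c))
            outputs.append(h)
        return torch.stack(outputs, dim=1)


model = ToyDecoder().eval()
example = (torch.randn(1, 7, 4), torch.tensor([7]))  # 7 steps, like the traced sentence
check = (torch.randn(1, 9, 4), torch.tensor([9]))    # 9 steps, like the failing sentence

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide TracerWarnings for this demo
    try:
        torch.jit.trace(model, example, check_inputs=[check])
        trace_ok = True
    except torch.jit.TracingCheckError as err:
        trace_ok = False
        # Print the first line of the tracer's sanity-check report.
        print(str(err).splitlines()[0])
```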

Observations:

  • The issue comes from the decoder part of the model.
  • I could not find any function or method call in the model definition that makes the input tensor's shape constant.
  • The issue appears to be caused by the input tensor's dimension being frozen during tracing, which ideally should not happen.
  • torch.jit.script is not an option here, as the model uses inheritance, etc., and is too complex to script.
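Since scripting the whole model is ruled out, one hedged option is TorchScript's support for mixing tracing and scripting: script only the small, loop-heavy part and trace the rest, because control flow inside scripted code that is called from a traced model is preserved. The sketch assumes the decoder loop can be isolated into its own scriptable module; `ToyDecoder`, `Seq2SeqWrapper` and the `nn.Linear` "encoder" are hypothetical stand-ins, and the real decoder may need extra changes (such as the `List[torch.Tensor]` annotation below) before `torch.jit.script` accepts it.

```python
import warnings
from typing import List

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical decoder whose loop length depends on the input length."""

    def __init__(self, hidden: int = 4):
        super().__init__()
        self.cell = nn.LSTMCell(hidden, hidden)

    def forward(self, x: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
        h = x.new_zeros([1, self.cell.hidden_size])
        c = x.new_zeros([1, self.cell.hidden_size])
        outputs: List[torch.Tensor] = []
        # Under torch.jit.script this stays a real loop instead of being unrolled.
        for t in range(int(length[0])):
            h, c = self.cell(x[:, t, :], (h, c))
            outputs.append(h)
        return torch.stack(outputs, dim=1)


class Seq2SeqWrapper(nn.Module):
    """Hypothetical model: a traceable part plus a scripted decoder submodule."""

    def __init__(self, hidden: int = 4):
        super().__init__()
        self.encoder = nn.Linear(hidden, hidden)  # stands in for the trace-friendly layers
        self.decoder = torch.jit.script(ToyDecoder(hidden))

    def forward(self, x: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x), length)


model = Seq2SeqWrapper().eval()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, (torch.randn(1, 7, 4), torch.tensor([7])))

for n in (6, 7, 9):
    out = traced(torch.randn(1, n, 4), torch.tensor([n]))
    print(n, tuple(out.shape))  # the step count now follows the length argument
```

The trade-off is that only the scripted submodule has to satisfy TorchScript's type rules, while the inheritance-heavy remainder can still be traced.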

Could you please help me determine whether this is a bug in the TorchScript tracing module, or whether I have missed some method or function call in the model architecture that makes the input tensor's dimension constant and thereby leaves no room for dynamic input lengths? Thanks in advance.