Weird Error when using TorchScript

I am implementing a variant of LSTMs using TorchScript by modifying the code in the fastrnn benchmark written by @tom, but I am getting a weird error:

RuntimeError: 
Return value was annotated as having type Tuple[Tensor, List[__torch__.model.subLSTM.nn.GRNState]] but is actually of type Tuple[Tensor, List[__torch__.model.subLSTM.nn.GRNState]]:
at ../../src/model/subLSTM/nn.py:214:9
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)

            output_states.append(out_state)
            i += 1

        if self.batch_first:
            output = output.transpose(0, 1)

        return output, output_states
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

This does not make sense to me, since the annotated and actual types are identical. The code for the model is the following:

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
import warnings
from collections import namedtuple
from typing import List, Tuple, Optional
from torch import Tensor


GRNState = namedtuple('GRNState', ['hx', 'cx'])


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


class SubLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(SubLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh).sigmoid()
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        cy = (forgetgate * cx) - (ingate - cellgate)
        hy = outgate - torch.tanh(cy)

        return hy, GRNState(hy, cy)


class LayerNormSubLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(LayerNormSubLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        self.layernorm_i = nn.LayerNorm(4 * hidden_size)
        self.layernorm_h = nn.LayerNorm(4 * hidden_size)
        self.layernorm_c = nn.LayerNorm(hidden_size)

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = (igates + hgates).sigmoid()
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        cy = self.layernorm_c((forgetgate * cx) + (ingate - cellgate))
        hy = outgate - torch.tanh(cy)

        return hy, GRNState(hy, cy)


class GRNLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(GRNLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        inputs = input.unbind(0)
        outputs: List[Tensor] = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class ReverseGRNLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(ReverseGRNLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state


class BidirLayer(jit.ScriptModule):
    __constants__ = ['directions']

    def __init__(self, cell, *cell_args):
        super(BidirLayer, self).__init__()
        self.directions = nn.ModuleList([
            GRNLayer(cell, *cell_args),
            ReverseGRNLayer(cell, *cell_args),
        ])

    @jit.script_method
    def forward(self, input: Tensor, states: List[GRNState]) -> Tuple[Tensor, List[GRNState]]:
        outputs: List[Tensor] = []
        output_states: List[GRNState] = []

        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states


def init_stacked_lstm(num_layers, layer, cell, input_size, hidden_size):
    layers = [layer(cell, input_size, hidden_size)] + \
             [layer(cell, hidden_size, hidden_size) for _ in range(num_layers - 1)]
    return nn.ModuleList(layers)


def init_states(num_layers, batch_size, hidden_size, device):
    states: List[GRNState] = []
    temp = torch.randn(num_layers, batch_size, hidden_size,
                       2, device=device).unbind(0)

    for s in temp:
        hx, cx = s.unbind(2)
        states.append(GRNState(hx, cx))

    return states


class SubLSTM(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ['layers', 'num_layers', 'batch_first', 'hidden_size']

    def __init__(self, input_size, hidden_size, num_layers, bias=True,
                 batch_first=False, dropout=0.0, bidirectional=False,
                 layer_norm=False):
        super(SubLSTM, self).__init__()

        layer = BidirLayer if bidirectional else GRNLayer
        cell = LayerNormSubLSTMCell if layer_norm else SubLSTMCell

        self.layers = init_stacked_lstm(
            num_layers, layer, cell, input_size, hidden_size)

        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout lstm adds dropout layers after all but last "
                          "recurrent layer, it expects num_layers greater than "
                          "1, but got num_layers = 1")

        self.dropout_layer = nn.Dropout(dropout)

        self.num_layers = num_layers
        self.batch_first = batch_first
        self.hidden_size = hidden_size


    @jit.script_method
    def forward(self, input: Tensor, states: Optional[List[GRNState]] = None) -> Tuple[Tensor, List[GRNState]]:
        output = input if not self.batch_first else input.transpose(0, 1)
        output_states: List[GRNState] = []

        if states is None:
            states = init_states(self.num_layers, output.size(1),
                                 self.hidden_size, output.device)

        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)

            # Apply dropout after every layer except the last one
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)

            output_states.append(out_state)
            i += 1

        if self.batch_first:
            output = output.transpose(0, 1)

        return output, output_states
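
For completeness, this is roughly how I instantiate and call the model (the sizes are just placeholders; the import path is taken from the traceback above):

import torch
from model.subLSTM.nn import SubLSTM

rnn = SubLSTM(input_size=16, hidden_size=32, num_layers=2)

# input of shape (seq_len, batch, input_size) since batch_first=False
x = torch.randn(10, 4, 16)

# states defaults to None, so the initial states are created internally
output, output_states = rnn(x)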

I didn’t write that code; I just used it for benchmarking.

So, are you using PyTorch 1.3? Maybe making those a plain tuple of Tensors instead of a namedtuple helps.
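
For instance, a sketch of what that could look like for the cell, keeping your math but using a plain Tuple[Tensor, Tensor] as the state type (untested):

import torch
import torch.jit as jit
from torch.nn import Parameter
from torch import Tensor
from typing import Tuple


class SubLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(SubLSTMCell, self).__init__()
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh).sigmoid()
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        cy = (forgetgate * cx) - (ingate - cellgate)
        hy = outgate - torch.tanh(cy)

        # Return a plain tuple instead of GRNState(hy, cy)
        return hy, (hy, cy)

The layer and stack signatures would then take Tuple[Tensor, Tensor] and List[Tuple[Tensor, Tensor]] instead of GRNState and List[GRNState].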

Best regards

Thomas

What version are you using? I tried this on master and it worked.
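
You can check the installed version with:

import torch
print(torch.__version__)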

Hi,

Thanks, it was the named tuple. That’s a bit annoying, since named tuples can make the code a lot cleaner. By the way, I don’t get the same level of performance as in the examples, which is a bit weird.

PS: Sorry for the very delayed reply.