I am implementing an LSTM variant using TorchScript by modifying the code from the fastrnn benchmark written by @tom, but I am getting a weird error:
RuntimeError:
Return value was annotated as having type Tuple[Tensor, List[__torch__.model.subLSTM.nn.GRNState]] but is actually of type Tuple[Tensor, List[__torch__.model.subLSTM.nn.GRNState]]:
at ../../src/model/subLSTM/nn.py:214:9
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states.append(out_state)
            i += 1
        if self.batch_first:
            output = output.transpose(0, 1)
        return output, output_states
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
This does not make sense to me, since the annotated type and the actual type are exactly the same. The code for the model is the following:
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
import warnings
from collections import namedtuple
from typing import List, Tuple
from torch import Tensor

GRNState = namedtuple('GRNState', ['hx', 'cx'])


def reverse(lst):
    # type: (List[Tensor]) -> List[Tensor]
    return lst[::-1]
class SubLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(SubLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh).sigmoid()
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        cy = (forgetgate * cx) - (ingate - cellgate)
        hy = outgate - torch.tanh(cy)
        return hy, GRNState(hy, cy)
class LayerNormSubLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(LayerNormSubLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases
        self.layernorm_i = nn.LayerNorm(4 * hidden_size)
        self.layernorm_h = nn.LayerNorm(4 * hidden_size)
        self.layernorm_c = nn.LayerNorm(hidden_size)

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = (igates + hgates).sigmoid()
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        cy = self.layernorm_c((forgetgate * cx) + (ingate - cellgate))
        hy = outgate - torch.tanh(cy)
        return hy, GRNState(hy, cy)
class GRNLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(GRNLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        inputs = input.unbind(0)
        outputs: List[Tensor] = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state
class ReverseGRNLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(ReverseGRNLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, state: GRNState) -> Tuple[Tensor, GRNState]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state
class BidirLayer(jit.ScriptModule):
    __constants__ = ['directions']

    def __init__(self, cell, *cell_args):
        super(BidirLayer, self).__init__()
        self.directions = nn.ModuleList([
            GRNLayer(cell, *cell_args),
            ReverseGRNLayer(cell, *cell_args),
        ])

    @jit.script_method
    def forward(self, input: Tensor, states: List[GRNState]) -> Tuple[Tensor, List[GRNState]]:
        outputs: List[Tensor] = []
        output_states: List[GRNState] = []
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states
def init_stacked_lstm(num_layers, layer, cell, input_size, hidden_size):
    layers = [layer(cell, input_size, hidden_size)] + \
        [layer(cell, hidden_size, hidden_size) for _ in range(num_layers - 1)]
    return nn.ModuleList(layers)


def init_states(num_layers, batch_size, hidden_size, device):
    states: List[GRNState] = []
    temp = torch.randn(num_layers, batch_size, hidden_size,
                       2, device=device).unbind(0)
    for s in temp:
        hx, cx = s.unbind(2)
        states.append(GRNState(hx, cx))
    return states
class SubLSTM(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ['layers', 'num_layers', 'batch_first', 'hidden_size']

    def __init__(self, input_size, hidden_size, num_layers, bias=True,
                 batch_first=False, dropout=0.0, bidirectional=False,
                 layer_norm=False):
        super(SubLSTM, self).__init__()

        layer = BidirLayer if bidirectional else GRNLayer
        cell = LayerNormSubLSTMCell if layer_norm else SubLSTMCell
        self.layers = init_stacked_lstm(
            num_layers, layer, cell, input_size, hidden_size)

        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout lstm adds dropout layers after all but last "
                          "recurrent layer, it expects num_layers greater than "
                          "1, but got num_layers = 1")

        self.dropout_layer = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.hidden_size = hidden_size

    @jit.script_method
    def forward(self, input: Tensor, states: List[GRNState]=None) -> Tuple[Tensor, List[GRNState]]:
        output = input if not self.batch_first else input.transpose(0, 1)
        output_states: List[GRNState] = []

        if states is None:
            states = init_states(self.num_layers, output.size(1),
                                 self.hidden_size, output.device)

        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply the dropout layer except the last layer
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states.append(out_state)
            i += 1

        if self.batch_first:
            output = output.transpose(0, 1)
        return output, output_states
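For reference, this is roughly how I construct and call the model when the error appears (a minimal sketch; the sizes are made up, and the import path is just my guess based on the __torch__.model.subLSTM.nn prefix in the trace, not necessarily how the project is laid out):

# Minimal reproduction sketch (sizes and import path are illustrative only).
import torch
from model.subLSTM.nn import SubLSTM  # assumed module path matching the trace

seq_len, batch_size, input_size, hidden_size = 10, 4, 8, 16
rnn = SubLSTM(input_size, hidden_size, num_layers=2, dropout=0.1)

x = torch.randn(seq_len, batch_size, input_size)
# The RuntimeError shown above is raised for the scripted forward here.
out, states = rnn(x)
print(out.shape)  # expected: (seq_len, batch_size, hidden_size)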