Extending LSTM with additional Gates

I am new to PyTorch and would appreciate some direction on how to create and use an LSTM cell with multiple additional gates. For example, I would like to implement the LSTM cell described in this paper.

You just take an LSTMCell implementation and modify it.
For example, you can copy the implementation here and modify it for your own purposes:

1 Like

very much appreciated!

Just a note: the fastrnns module seems to work only with PyTorch v1.0.0.

I am having a little difficulty getting the custom LSTM cell implementation à la https://github.com/pytorch/benchmark/tree/master/rnns/fastrnns to work. Please see my code with a minimal implementation below.
custom RNN

import torch
from collections import namedtuple

"""
Define a creator as a function:
(options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
    forward(*inputs) directly.
params: List[Tensor] all requires_grad=True parameters.
forward: function / graph executor / module
    One can call rnn(rnn_inputs) using the outputs of the creator.
backward_setup: backward_inputs = backward_setup(*outputs)
    Then, we pass backward_inputs to backward. If None, then it is assumed to
    be the identity function.
backward: Given `output = backward_setup(*forward(*inputs))`, performs
    backpropagation. If None, then nothing happens.

fastrnns.bench times the forward and backward invocations.
"""

def lstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
    # type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
    # One LSTM step. The fused pre-activations are split along dim 1 into the
    # input, forget, cell-candidate and output gates (in that order), and the
    # new (h, c) pair is returned. Documented with comments rather than a
    # docstring because this function may be compiled with torch.jit.script.
    h_prev, c_prev = hidden
    pre_act = (torch.mm(input, w_ih.t()) + torch.mm(h_prev, w_hh.t())
               + b_ih + b_hh)
    i_raw, f_raw, g_raw, o_raw = pre_act.chunk(4, 1)

    i_gate = torch.sigmoid(i_raw)
    f_gate = torch.sigmoid(f_raw)
    g_cand = torch.tanh(g_raw)
    o_gate = torch.sigmoid(o_raw)

    c_next = (f_gate * c_prev) + (i_gate * g_cand)
    h_next = o_gate * torch.tanh(c_next)
    return h_next, c_next

def varlen_lstm_creator(script=False, **kwargs):
    """Build a ``ModelDef`` for the variable-length LSTM benchmark.

    Args:
        script: if True, the forward loop and ``lstm_cell`` are compiled with
            ``torch.jit.script``.
        **kwargs: forwarded to ``varlen_lstm_inputs`` (sizes, device, seed...).

    Returns:
        A ``ModelDef`` namedtuple. It is NOT callable; run the model with
        ``model.forward(*model.inputs)`` (or pass your own sequences/hidden
        state followed by the four weights ``model.inputs[2:]``).
    """
    sequences, _, hidden, params, _ = varlen_lstm_inputs(
        return_module=False, **kwargs)
    # Positional arguments for `forward`: the list of sequences, the (hx, cx)
    # pair, then layer 0's weights (w_ih, w_hh, b_ih, b_hh).
    inputs = [sequences, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=varlen_lstm_factory(lstm_cell, script),
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward)

# Plain record describing one benchmark model, produced by a creator:
#   inputs          - positional arguments for `forward`
#   params          - the requires_grad=True parameters
#   forward         - the callable that runs the model
#   backward_setup  - maps forward outputs to backward inputs (None = identity)
#   backward        - performs backpropagation (None = no-op)
# Being a namedtuple, it is not itself callable; use `.forward(*.inputs)`.
ModelDef = namedtuple(
    'ModelDef',
    (
        'inputs',
        'params',
        'forward',
        'backward_setup',
        'backward',
    ),
)

def varlen_lstm_inputs(minlen=30, maxlen=100,
                       numLayers=1, inputSize=512, hiddenSize=512,
                       miniBatch=64, return_module=False, device='cpu',
                       seed=None, **kwargs):
    """Create random variable-length LSTM inputs and a reference ``nn.LSTM``.

    Args:
        minlen, maxlen: sequence lengths are drawn with ``torch.randint`` from
            ``[minlen, maxlen)`` (the upper bound is exclusive).
        numLayers, inputSize, hiddenSize: configuration of the
            ``torch.nn.LSTM`` whose weights are returned.
        miniBatch: number of sequences to generate.
        return_module: if True, also return the ``torch.nn.LSTM`` module.
        device: device for every created tensor and for the module.
        seed: if not None, passed to ``torch.manual_seed`` first.
        **kwargs: ignored; lets callers pass a shared option dict.

    Returns:
        ``(x, lengths, (hx, cx), all_weights, module_or_None)`` where ``x`` is
        a list of ``miniBatch`` tensors of shape ``(length, inputSize)``,
        ``hx``/``cx`` have shape ``(numLayers, miniBatch, hiddenSize)`` and
        ``all_weights`` is ``lstm.all_weights``
        (per layer: ``wih, whh, bih, bhh = all_weights[layer]``).
    """
    if seed is not None:
        torch.manual_seed(seed)
    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch],
        dtype=torch.long, device=device)
    x = [torch.randn(int(length), inputSize, device=device)
         for length in lengths]
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    # Weights are always returned; the module itself only on request.
    module = lstm if return_module else None
    return x, lengths, (hx, cx), lstm.all_weights, module

def varlen_lstm_backward_setup(forward_output, seed=None):
    """Pad the forward outputs and draw a matching random upstream gradient.

    Args:
        forward_output: value returned by the forward function; only its
            first element (the list of per-sequence outputs) is used.
        seed: if not None, passed to ``torch.manual_seed`` before drawing the
            gradient. ``seed=0`` is honored too (a truthiness check skipped it).

    Returns:
        ``(padded, grad)``: the outputs padded with
        ``torch.nn.utils.rnn.pad_sequence`` (time-major, ``(max_len, batch, *)``)
        and a ``randn`` tensor of the same shape.
    """
    if seed is not None:
        torch.manual_seed(seed)
    rnn_utils = torch.nn.utils.rnn
    sequences = forward_output[0]
    padded = rnn_utils.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad

def varlen_lstm_factory(cell, script):
    """Wrap a single-step ``cell`` into a loop over variable-length sequences.

    Args:
        cell: step function called as
            ``cell(x_t, (h, c), wih, whh, bih, bhh) -> (h, c)``.
        script: if True, both ``cell`` and the returned loop are compiled with
            ``torch.jit.script``.

    Returns:
        ``dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh)``. It processes
        each sequence separately (one batch element at a time) and returns
        ``(outputs, (hx_outs, cx_outs))``.
    """
    def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh):
        # type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]
        # Comments rather than a docstring: this body may be compiled by
        # TorchScript, so it sticks to simple list/loop constructs.
        h0, c0 = hiddens
        # One initial-state slice per sequence, taken along dim 1.
        h_init = h0.unbind(1)
        c_init = c0.unbind(1)
        seq_outputs = []
        final_h = []
        final_c = []

        for b in range(len(sequences)):
            h_t = h_init[b]
            c_t = c_init[b]
            steps = sequences[b].unbind(0)
            step_outputs = []
            for t in range(len(steps)):
                # Give the step input a leading batch dimension of size 1.
                x_t = steps[t].unsqueeze(0)
                h_t, c_t = cell(x_t, (h_t, c_t), wih, whh, bih, bhh)
                step_outputs.append(h_t)
            seq_outputs.append(torch.stack(step_outputs))
            final_h.append(h_t.unsqueeze(0))
            final_c.append(c_t.unsqueeze(0))

        return seq_outputs, (final_h, final_c)

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn

def simple_backward(output, grad_output):
    """Backpropagate from ``output``, seeding autograd with ``grad_output``.

    Gradients accumulate into the leaves' ``.grad``; the return value is
    whatever ``output.backward`` returns (``None`` for tensors).
    """
    result = output.backward(grad_output)
    return result

# list[list[T]] -> list[T]
def flatten_list(lst):
    """Concatenate the inner lists of ``lst`` into one flat list.

    Only one level is flattened; inner elements are kept as-is.

    >>> flatten_list([[1, 2], [3], []])
    [1, 2, 3]
    """
    return [item for inner in lst for item in inner]

A basic proof of concept (POC):

import torch
import torch.nn as nn
from CustomRNN import varlen_lstm_creator


lstm = varlen_lstm_creator(inputSize=3, hiddenSize=3)  # Input dim is 3, output dim is 3
# varlen_lstm_creator returns a ModelDef namedtuple, which is not callable
# (hence "'ModelDef' object is not callable"). Call its `forward` field
# instead: forward(sequences, (hx, cx), w_ih, w_hh, b_ih, b_hh), where the
# four weights are the tail of `lstm.inputs`.
w_ih, w_hh, b_ih, b_hh = lstm.inputs[2:]
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5

# initialize the hidden state: (num_layers=1, batch=1, hidden=3).
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
for i in inputs:
    # Step through the sequence one element at a time: each call sees one
    # length-1 sequence. forward returns per-sequence lists of final states,
    # so concatenate them back along the batch dim (1) before reusing them.
    outs, (hx_outs, cx_outs) = lstm.forward([i], hidden, w_ih, w_hh, b_ih, b_hh)
    out = outs[0]
    hidden = (torch.cat(hx_outs, 1), torch.cat(cx_outs, 1))

# Alternatively, run the whole sequence in one call.
sequence = torch.cat(inputs)  # (5, 3): one sequence of length 5
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # clean out hidden state
outs, (hx_outs, cx_outs) = lstm.forward([sequence], hidden, w_ih, w_hh, b_ih, b_hh)
out = outs[0]  # (5, 1, 3): one hidden state per time step
hidden = (torch.cat(hx_outs, 1), torch.cat(cx_outs, 1))

TypeError: 'ModelDef' object is not callable

Any help is greatly appreciated!

The whole file and its helper functions are not callable as-is; it's part of a larger benchmark harness.

I meant that you probably want to just grab the cell functions themselves.

OK. But how do I integrate a cell function with, for example, torch.nn.LSTM? I appreciate your time.

You don't integrate a cell function with torch.nn.LSTM; you write a for loop around it.

Something like:

# Sketch: drive a single-step cell function with an explicit time loop.
# F.lstm_cell is not part of the documented torch.nn.functional API, so use a
# cell function such as lstm_cell, which takes (x_t, (h, c), w_ih, w_hh,
# b_ih, b_hh). It multiplies with torch.mm, so the whole batch is handled in
# one call and no per-batch loop is needed.
# Assumes `input` is time-major (timesteps, batch, features) -- adjust the
# indexing if yours is batch-first.
for t in range(timesteps):
    h, c = lstm_cell(input[t], (h, c), w_ih, w_hh, b_ih, b_hh)
    out = F.linear(h, weight_out)  # F.linear(input, weight): input comes first
1 Like

This would be an example of such a loop that you would write: https://github.com/pytorch/benchmark/blob/09eaadc1d05ad442b1f0beb82babf875bbafb24b/rnns/fastrnns/factory.py#L165-L182

I was able to create and use a number of custom RNN/LSTM cells using https://github.com/NVIDIA/apex/tree/master/apex/RNN.

I am struggling with a custom LSTM cell that returns tuples for both the hidden and cell state. Please find minimal code below. Any help is appreciated.

class skipLSTMRNNCell(RNNCell):
    """RNNCell subclass that drives ``SkipLSTMCell``.

    Keeps four recurrent tensors in ``self.hidden`` (``n_hidden_states=4``),
    in the order ``SkipLSTMCell`` unpacks them:
    ``(c, h, update_prob, cum_update_prob)``. Adds the update-probability
    projection parameters ``w_uh`` / ``b_uh``.
    """

    def __init__(self, input_size, hidden_size, bias=False, output_size=None):
        gate_multiplier = 4
        super(skipLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, SkipLSTMCell, n_hidden_states=4,
                                              bias=bias, output_size=output_size)

        # Projection of the candidate cell state to one update-probability
        # logit; SkipLSTMCell applies it as F.linear(c_tilde, w_uh, b_uh).
        self.w_uh = Parameter(xavier_uniform(torch.Tensor(1, hidden_size)))
        self.b_uh = Parameter(torch.ones(1))

    def forward(self, input):
        # if not inited or bsz has changed this will create hidden states
        # NOTE(review): assumes the RNNCell base provides init_hidden(bsz), as
        # apex's RNNCell does; this call appears to have been lost.
        self.init_hidden(input.size()[0])

        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden

        # SkipLSTMCell returns TWO tuples: (new_output, new_state), with
        # new_output = (new_h, update_gate) and
        # new_state = (c, h, update_prob, cum_update_prob).
        # Only new_state is the recurrent state for the next step; storing the
        # outer pair in self.hidden is what broke the previous version.
        output, state = self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_uh, self.b_uh,
                                  b_ih=self.b_ih, b_hh=self.b_hh)
        self.hidden = list(state)

        if self.output_size != self.hidden_size:
            # NOTE(review): with SkipLSTMCell's state order, self.hidden[0] is
            # the cell state c, not h; projecting it changes c's size for the
            # next step. Confirm which tensor should be projected.
            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
        return tuple(self.hidden)

    def new_like(self, new_input_size=None):
        """Return a fresh cell with this configuration and optionally a new input size."""
        if new_input_size is None:
            new_input_size = self.input_size

        # This subclass's constructor is (input_size, hidden_size, bias,
        # output_size); the gate multiplier, cell function and state count are
        # fixed in __init__.
        # NOTE(review): assumes the base class stores `bias` as self.bias.
        return type(self)(new_input_size, self.hidden_size, self.bias, self.output_size)



def SkipLSTMCell(input, hidden, w_ih, w_hh, w_uh, b_uh=None, b_ih=None, b_hh=None):
    """One step of a Skip-LSTM: an LSTM step blended by an update gate.

    Args:
        input: input for this step.
        hidden: ``(c_prev, h_prev, update_prob_prev, cum_update_prob_prev)``.
        w_ih, w_hh, b_ih, b_hh: LSTM weights. Their output is split into 4
            along dim 1 as (input, forget, cell, output) gates, so they need
            ``4 * hidden`` output features.
        w_uh, b_uh: projection of the candidate cell state to the update
            probability.

    Returns:
        ``(new_output, new_state)``, two tuples:
        ``new_output = (new_h, update_gate)`` and
        ``new_state = (new_c, new_h, new_update_prob, new_cum_update_prob)``.
        Callers must carry ``new_state`` (not the outer pair) to the next step.
    """
    c_prev, h_prev, update_prob_prev, cum_update_prob_prev = hidden

    gates = F.linear(input, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    # Split the fused pre-activations into 4 equal slices along dim 1.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    # torch.sigmoid replaces the deprecated F.sigmoid (same values).
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    # Candidate LSTM state, as if this step were taken.
    new_c_tilde = (forgetgate * c_prev) + (ingate * cellgate)
    new_h_tilde = outgate * torch.tanh(new_c_tilde)
    # Compute value for the update prob
    new_update_prob_tilde = torch.sigmoid(F.linear(new_c_tilde, w_uh, b_uh))
    # Compute value for the update gate: add the previous update probability,
    # capped by min(...) so the accumulated value never exceeds 1.
    cum_update_prob = cum_update_prob_prev + torch.min(update_prob_prev, 1. - cum_update_prob_prev)
    # round
    # NOTE(review): BinaryLayer is defined elsewhere. If it is a legacy
    # (non-static) torch.autograd.Function, instantiating and calling it is
    # not supported in recent PyTorch; BinaryLayer.apply(...) would be needed.
    bn = BinaryLayer()
    update_gate = bn(cum_update_prob)
    # Apply update gate: where update_gate is 0 the previous state is kept.
    new_c = update_gate * new_c_tilde + (1. - update_gate) * c_prev
    new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev
    # Reset the accumulated probability where an update happened.
    new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob
    new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev

    new_state = (new_c, new_h, new_update_prob, new_cum_update_prob)
    new_output = (new_h, update_gate)

    return new_output, new_state

Can anyone explain what this line is doing?

# Split the fused gate pre-activations into 4 equal slices along dimension 1
# (Tensor.chunk(chunks=4, dim=1)) and unpack them into four names.
# Assuming `gates` has shape (batch, 4 * hidden), each slice is
# (batch, hidden). The names must match how the rows of w_ih / w_hh are laid
# out; for torch.nn.LSTM weights that documented order is input, forget,
# cell (g), output.
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)