Correct way to declare hidden and cell states of LSTM

Simple LSTM Cell like below… I declare my cell state thus…

# NOTE(review): model.double()/model.cuda() only convert registered parameters and buffers,
# not plain tensor attributes like this one -- presumably why the dtype is forced here by hand.
self.c_t = Variable(torch.zeros(batch_size, cell_size), requires_grad=False).double()

I really don’t like having to call .double().cuda() on my hidden Variable, but if I don’t, the model breaks…

Is there a way to fix this? I tried using Parameters, but the LSTMCell returns a Variable, so I got a type error.

What is the “correct” way to set up hidden variables for LSTMCell?

import torch
from torch import nn
from torch.nn import Parameter
from torch.autograd import Variable


class SimpleLSTM(nn.Module):
    """Unroll an ``nn.LSTMCell`` over the input and project every hidden state.

    ``input`` is split along dimension 2 into ``sequence_length`` chunks; each
    chunk is squeezed on dimension 2 and fed to the cell as one time step, so
    the input is expected to be ``(batch, input_dims, sequence_length)``.
    The per-step projections (``(batch, output_features)`` each) are
    concatenated along dimension 1.

    Args:
        batch_size: kept for backward compatibility; the batch size actually
            used is read from the input on every call.
        input_dims: feature size of one time step.
        sequence_length: number of chunks the input is split into.
        cell_size: hidden/cell state size of the LSTM cell.
        output_features: size of the per-step linear projection.
    """

    def __init__(self, batch_size, input_dims, sequence_length, cell_size, output_features=1):
        super(SimpleLSTM, self).__init__()
        self.input_dims = input_dims
        self.sequence_length = sequence_length
        self.cell_size = cell_size
        self.lstm = nn.LSTMCell(input_dims, cell_size)
        self.to_output = nn.Linear(cell_size, output_features)
        # Kept as attributes for callers that read them; ``forward`` rebuilds
        # fresh states each call, so no dtype/device conversion is needed here.
        self.h_t = Variable(torch.zeros(batch_size, cell_size), requires_grad=False)
        self.c_t = Variable(torch.zeros(batch_size, cell_size), requires_grad=False)

    def forward(self, input):
        """Run the cell over all time steps and return the concatenated outputs.

        The final hidden and cell states are stored on ``self.h_t`` /
        ``self.c_t`` after the call.
        """
        batch_size = input.size(0)
        # New zero tensors with the same type as the model's parameters, so the
        # states automatically follow ``model.double()`` / ``model.cuda()``.
        # Allocating fresh tensors (instead of zeroing last call's outputs in
        # place) also avoids mutating tensors owned by the previous graph.
        weight = next(self.parameters()).data
        h_t = Variable(weight.new(batch_size, self.cell_size).zero_(), requires_grad=False)
        c_t = Variable(weight.new(batch_size, self.cell_size).zero_(), requires_grad=False)

        outputs = []
        for input_t in torch.chunk(input, self.sequence_length, dim=2):
            h_t, c_t = self.lstm(input_t.squeeze(2), (h_t, c_t))
            outputs.append(self.to_output(h_t))

        # Expose the final states as the original implementation did.
        self.h_t, self.c_t = h_t, c_t
        return torch.cat(outputs, dim=1)

Hi Duane!

In most cases you can sidestep this issue by using nn.LSTM instead of nn.LSTMCell.

docs: http://pytorch.org/docs/0.3.1/nn.html#lstm

nn.LSTM takes your full sequence (rather than chunks), automatically initializes the hidden and cell states to zeros if you don’t pass them, runs the LSTM over the full sequence (updating the state along the way), and returns the outputs for every time step together with the final hidden/cell state.

If you do need to initialize a hidden state because you’re decoding one item at a time or some similar situation,
I usually make a method like this:

def init_hidden(self, batch_size):
    """Return zero ``(hidden, cell)`` states shaped ``(num_layers, batch_size, hidden_size)``.

    The tensors are created with the same type as the model's first parameter,
    so they follow ``model.cuda()`` / ``model.double()`` automatically.
    """
    weight = next(self.parameters()).data
    # ``.new(...)`` allocates *uninitialized* memory, so zero it explicitly;
    # otherwise the initial state would be whatever garbage was in memory.
    hidden = Variable(weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
    cell = Variable(weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
    return (hidden, cell)

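next(self.parameters()).data.new() looks arcane, but all it’s doing is grabbing the first parameter in the model and making a new tensor of the same type with the specified dimensions. This way, if you call .cuda() on the model, it’ll return CUDA tensors instead. Note that .new() does not initialize the memory, so zero the result (e.g. with .zero_()) before using it as an initial state.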

7 Likes

Austin! Thanks buddy :smile:

Great advice as always… here’s the grad-checked code I ended up with.

import torch
from torch import nn
from torch.autograd import Variable


class SimpleLSTM(nn.Module):
    """Step an ``nn.LSTMCell`` through the input and project each hidden state.

    The input is cut along dimension 2 into ``sequence_length`` pieces; each
    piece, with dimension 2 squeezed away, is one cell step. The linear
    projections of all steps are joined along dimension 1.
    """

    def __init__(self, batch_size, input_dims, sequence_length, cell_size, output_features=1):
        super(SimpleLSTM, self).__init__()
        self.cell_size = cell_size
        self.sequence_length = sequence_length
        self.input_dims = input_dims
        # Submodule creation order is kept (cell first, then projection) so
        # random initialisation consumes the RNG identically.
        self.lstm = nn.LSTMCell(input_dims, cell_size)
        self.to_output = nn.Linear(cell_size, output_features)

    def forward(self, input):
        """Return the concatenated per-step projections for ``input``."""
        state = self.init_hidden(input.size(0))
        projected = []
        for step in torch.chunk(input, self.sequence_length, dim=2):
            # The cell returns the (hidden, cell) pair, which is exactly the
            # state tuple expected on the next step.
            state = self.lstm(step.squeeze(2), state)
            projected.append(self.to_output(state[0]))
        return torch.cat(projected, dim=1)

    def init_hidden(self, batch_size):
        """Build zero (hidden, cell) states typed like the model's parameters."""
        template = next(self.parameters()).data

        def _blank():
            # Fresh, gradient-free state; ``new`` copies the parameter's
            # tensor type (e.g. double / CUDA) but not its contents.
            fresh = Variable(template.new(batch_size, self.cell_size), requires_grad=False)
            return fresh.zero_()

        return _blank(), _blank()

Hi Austin, doesn’t this mean that the initial cell state and hidden state are different for each element in the batch? For example, if I change the order of the examples given as input to the network, the outputs are going to be different, right? Also, shouldn’t requires_grad be set to True?

1 Like