Multilayer RNN using RNNCell

Hey all,

I am trying to implement a fully connected multilayer RNN using torch.nn.RNNCell.

I have implemented it, but it does not seem to learn anything.

Here is the code for reference:

import torch
import torch.nn as nn

class MultilayerRNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_layers=[100, 100, 100]):
        super(MultilayerRNN, self).__init__()
        assert len(hidden_layers) > 0

        self.hidden_layers = hidden_layers
        self.num_layers = len(hidden_layers)
        self.input_size = input_size

        self.embedding = nn.Embedding(self.input_size, self.hidden_layers[0])
        self.layer_out = nn.Linear(hidden_layers[-1], output_size)
        
    def forward(self, input):
        # input.shape: [BATCH_SIZE, INPUT_SIZE], e.g. [16, 128]

        # After transpose,
        # input.shape: [INPUT_SIZE, BATCH_SIZE], e.g. [128, 16]
        input = input.t()

        # After embedding,
        # input.shape: [INPUT_SIZE, BATCH_SIZE, EMBEDDING_DIM], e.g. [128, 16, 100]
        input = self.embedding(input)

        for i, out_size in enumerate(self.hidden_layers):
            if i == 0:
               # For first hidden layer, input_size == embedding_dim
                in_size = input.size(2)
            else:
                in_size = self.hidden_layers[i-1]

            rnn = nn.RNNCell(in_size, out_size, 'tanh')
            # hidden state must match the cell's hidden size, not its input size
            hx = self.__init_hidden(input.size(1), out_size)

            input = self.__cell(rnn, input, hx)

        output = self.layer_out(input[-1])
        return output, input[-1]
            

    def __cell(self, cell, input, hx):
        inputs = input.unbind(0)
        outputs = []

        for i in range(len(inputs)):
            hx = cell(inputs[i], hx)
            outputs += [hx]

        return torch.stack(outputs)

    def __init_hidden(self, batch_size, hidden_size):
        # Variable is deprecated since PyTorch 0.4; plain tensors suffice
        return torch.zeros(batch_size, hidden_size)

It always gives around 50% accuracy, which is chance level for binary classification.

The same dataset gives around 85% accuracy with standard RNN modules (e.g. nn.RNNBase or nn.RNN).
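For reference, the baseline comparison was along these lines. This is a sketch, not my exact script: the class name, hidden size, and layer count are assumptions mirroring the hidden_layers default above.

```python
import torch
import torch.nn as nn

class BaselineRNN(nn.Module):
    """Stacked-RNN baseline using the built-in nn.RNN module."""

    def __init__(self, vocab_size, output_size, hidden_size=100, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # nn.RNN stacks the layers internally and owns all their weights,
        # so everything is registered with the module (and the optimizer).
        self.rnn = nn.RNN(hidden_size, hidden_size, num_layers=num_layers,
                          nonlinearity='tanh')
        self.layer_out = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        # input: [BATCH_SIZE, SEQ_LEN] -> embedded: [SEQ_LEN, BATCH_SIZE, EMB_DIM]
        x = self.embedding(input.t())
        output, _ = self.rnn(x)          # output: [SEQ_LEN, BATCH_SIZE, HIDDEN]
        return self.layer_out(output[-1])  # classify from the last time step
```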

I am pretty sure something is wrong in my forward method, but I am unable to pin it down.

Any help is appreciated.
Thank you.

Best,
Harshil

Hey all,
Just as an update: I solved this by moving the creation of the nn.RNNCell layers from the forward method into __init__. Constructing the cells inside forward gives them freshly initialized weights on every pass, and those parameters are never registered with the module, so the optimizer never trains them.
Thank you.