If statement for initialising a zero-valued hidden state matrix in RnnCell not triggering

Dear community,

I am currently writing my own RNN for a simple sequential MNIST classification task. While I do not experience an issue during training, I am wondering why the following if statement in the forward pass of the RnnCell class is not triggering:

if hidden_state is None:
            hidden_state = torch.zeros(input.shape[0], self.hidden_size).to(device)
            print("If statement triggered!")

In `forward`, `hidden_state` defaults to None, so in theory the statement should trigger, but the print statement inside it never prints. Am I missing something? Please let me know. You can get a runnable notebook or Colab example by copying and pasting the code below into your notebook. A small test case is included as well.

Any hints or thoughts would be appreciated. I don't think I need to initialise the hidden_state within the RnnCell, since training works fine. However, I would like to understand what I am doing wrong so that I can learn from it.

Best,

weight_theta

import torch
from torch import nn
from torch.autograd import Variable
from torch.autograd import Variable
# Default compute device for this script: the first CUDA device when one is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class RnnCell(nn.Module):
    """Single recurrent cell: out = act(input2hidden(x) + hidden2hidden(h)).

    Args:
        input_size: number of features in each input vector.
        hidden_size: number of features in the hidden state.
        activation: name of the non-linearity, one of "tanh", "relu" or
            "sigmoid".

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    # Activation name -> element-wise function applied to the pre-activation.
    # One table drives both validation and dispatch, so they cannot drift apart.
    _ACTIVATIONS = {
        "tanh": torch.tanh,
        "relu": torch.relu,
        "sigmoid": torch.sigmoid,
    }

    def __init__(self, input_size, hidden_size, activation="tanh"):
        super(RnnCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation
        if self.activation not in self._ACTIVATIONS:
            raise ValueError("Invalid nonlinearity selected for RNN. Please use tanh, relu or sigmoid.")

        # Input-to-hidden projection applied to the current input x_t.
        self.input2hidden = nn.Linear(input_size, hidden_size)
        # Recurrent hidden-to-hidden projection applied to this cell's own
        # previous state h_{t-1} (it is not a stacking or output layer).
        self.hidden2hidden = nn.Linear(hidden_size, hidden_size)

    def forward(self, input, hidden_state=None):
        """Advance the cell by one time step.

        Args:
            input: tensor of shape [batch_size, input_size].
            hidden_state: tensor of shape [batch_size, hidden_size], or None
                to start from an all-zero state.

        Returns:
            Tensor of shape [batch_size, hidden_size]: the activated new state.
        """
        if hidden_state is None:
            # Build the zero state from ``input`` so that it shares the input's
            # device and dtype. A module-level device breaks when the cell lives
            # elsewhere (e.g. a CPU model on a CUDA machine, or float64 input).
            hidden_state = input.new_zeros(input.shape[0], self.hidden_size)
            # Debug trace. SimpleRNN always passes an explicit state, so this
            # only fires when the cell is called directly without one.
            print("If statement triggered!")

        pre_activation = self.input2hidden(input) + self.hidden2hidden(hidden_state)
        return self._ACTIVATIONS[self.activation](pre_activation)

    def init_weights_normal(self):
        """Re-draw every parameter (weights and biases) in place from N(0, 0.02**2).

        0.02 is the standard deviation passed to ``Tensor.normal_``.
        """
        # no_grad replaces the legacy ``.data`` access for in-place re-init.
        with torch.no_grad():
            for param in self.parameters():
                param.normal_(0, 0.02)

class SimpleRNN(nn.Module):
    """Stack of ``num_layers`` RnnCells followed by a linear read-out.

    The first cell maps ``input_size`` -> ``hidden_size``. Each further cell
    maps ``hidden_size`` -> ``hidden_size`` and consumes the updated state of
    the layer below. The read-out is applied to the top layer's state after
    the final time step.

    Args:
        input_size: features per time step.
        hidden_size: features in each layer's hidden state.
        num_layers: number of stacked cells.
        output_size: features produced by the final linear layer.
        activation: "tanh", "relu" or "sigmoid", shared by all cells.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, activation='relu'):
        super(SimpleRNN, self).__init__()
        if activation not in ('tanh', 'relu', 'sigmoid'):
            raise ValueError("Invalid activation. Please use tanh, relu or sigmoid activation.")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size

        # Layer 0 consumes the raw input. Deeper layers consume the state of the
        # layer below. Cells are created in the same order as before, so
        # parameter names and random initialisation are unchanged.
        self.rnn_cell_list = nn.ModuleList(
            [RnnCell(self.input_size, self.hidden_size, activation)]
            + [RnnCell(self.hidden_size, self.hidden_size, activation)
               for _ in range(1, self.num_layers)]
        )

        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden_state=None):
        """Run the stack over a whole sequence and project the final state.

        Args:
            input: tensor of shape [batch_size, seq_len, input_size].
            hidden_state: optional initial state of shape
                [num_layers, batch_size, hidden_size]. Zeros are used when it
                is None.

        Returns:
            Tensor of shape [batch_size, output_size], computed from the top
            layer's state at the last time step.
        """
        if hidden_state is None:
            # Allocate on the input's device and dtype instead of "CUDA if
            # available", which broke CPU models on CUDA machines. Variable
            # has been a no-op wrapper since PyTorch 0.4.
            h0 = input.new_zeros(self.num_layers, input.size(0), self.hidden_size)
        else:
            h0 = hidden_state

        # Running state per layer, each of shape [batch_size, hidden_size].
        hidden = [h0[layer] for layer in range(self.num_layers)]

        for t in range(input.size(1)):
            layer_input = input[:, t, :]
            for layer in range(self.num_layers):
                hidden[layer] = self.rnn_cell_list[layer](layer_input, hidden[layer])
                # The freshly updated state feeds the next layer at this step.
                layer_input = hidden[layer]

        # hidden[-1] is the top layer's state after the last time step. No
        # .squeeze(): it dropped the batch dimension when batch_size == 1.
        return self.fc(hidden[-1])

def test():
    """Smoke-test SimpleRNN on a random batch and check the output shape.

    Builds 64 sequences of 28*28 steps with 28*28 features per step. Feature 0
    is random and the rest are the constant 28*28. The model is run once.

    Returns:
        (x, xshape): the input tensor and the shape of the model output.

    Raises:
        AssertionError: if the output shape is not [64, 10].
    """
    model = SimpleRNN(input_size=28*28, hidden_size=128, num_layers=3, output_size=10)
    model = model.to(device)
    # Input layout: [batch size, sequence length, input size].
    x = torch.randn(64, 28*28)
    x = x.unsqueeze(-1)
    vals = torch.ones(64, 28*28, 28*28-1) * (28*28)
    x = torch.cat([x, vals], dim=-1).to(device)
    out = model(x)
    xshape = out.shape
    # Check the documented [batch size, output size] contract explicitly.
    # A bare assert would be stripped under ``python -O``.
    if xshape != (64, 10):
        raise AssertionError(f"unexpected output shape {tuple(xshape)}")
    return x, xshape


testx, xdims = test()
print("Size test: passed.")

You are calling the RnnCells with a hidden state via:

if layer == 0:
    hidden_l = self.rnn_cell_list[layer](input[:, t, :], hidden[layer])
else:
    hidden_l = self.rnn_cell_list[layer](hidden[layer - 1],hidden[layer])

which is previously defined as:

        if hidden_state is None:
            if torch.cuda.is_available():
                h0 = torch.zeros(self.num_layers, input.size(0), self.hidden_size).cuda()
            else:
                h0 = torch.zeros(self.num_layers, input.size(0), self.hidden_size)

        else:
             h0 = hidden_state

        outs = []

        hidden = list()
        for layer in range(self.num_layers):
            hidden.append(h0[layer, :, :])

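and thus never None. On the first time step `hidden[layer]` is a slice of the zero tensor `h0`, and on every later step it is the previous output `hidden_l`.
If I remove the `hidden[layer]` argument from `self.rnn_cell_list[layer](input[:, t, :], hidden[layer])`, I see the print statements as expected. Note that this also resets the first layer's state to zeros at every step, so it is only useful as a check.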

1 Like

Thank you for your quick reply, Patrick. So the hidden state matrix is already initialised with torch.zeros as h0, and is therefore, as you say, never None. This explains why `if hidden_state is None:` never triggers. It also means the if clause I added is redundant when RnnCell is only called from SimpleRNN, and can be removed. Great, thank you for pointing out my mistake.