Dear community,
I am currently writing my own RNN for a simple sequential MNIST classification task, and while I do not experience any issue during training, I am wondering why the following if statement in the forward pass of the RnnCell class is not triggering:
if hidden_state is None:
    hidden_state = torch.zeros(input.shape[0], self.hidden_size).to(device)
    print("If statement triggered!")
In the forward, hidden_state defaults to None, so the statement should in theory trigger, but the print statement inside it never gets printed. Am I missing something? Please let me know. A runnable notebook/Colab example can be found below by simply copy-pasting it into your notebook. A small test case is included as well.
Any hints or thoughts would be appreciated. I don't think I need to initialise the hidden_state within the RnnCell, since training works fine; I would, however, like to understand what I am doing wrong. (A minimal isolation check of the cell is included at the end of the listing.)
Best,
weight_theta
import torch
from torch import nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class RnnCell(nn.Module):
def __init__(self, input_size, hidden_size, activation="tanh"):
super(RnnCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.activation = activation
if self.activation not in ["tanh", "relu", "sigmoid"]:
raise ValueError("Invalid nonlinearity selected for RNN. Please use tanh, relu or sigmoid.")
self.input2hidden = nn.Linear(input_size, hidden_size)
        # hidden2hidden: recurrent connection applied to the previous hidden state
        self.hidden2hidden = nn.Linear(hidden_size, hidden_size)
def forward(self, input, hidden_state=None):
        '''
        Inputs: input (torch tensor) of shape [batch_size, input_size]
                hidden_state (torch tensor) of shape [batch_size, hidden_size]
        Output: out (torch tensor) of shape [batch_size, hidden_size]
        '''
        # initialise the hidden state on the first step if none is passed in
        if hidden_state is None:
            hidden_state = torch.zeros(input.shape[0], self.hidden_size).to(device)
            print("If statement triggered!")
hidden_state = (self.input2hidden(input) + self.hidden2hidden(hidden_state))
# takes output from hidden and apply activation
if self.activation == "tanh":
out = torch.tanh(hidden_state)
elif self.activation == "relu":
out = torch.relu(hidden_state)
elif self.activation == "sigmoid":
out = torch.sigmoid(hidden_state)
return out
    def init_weights_normal(self):
        # iterate over the parameters (weights theta)
        # and initialise them from a normal centred at 0 with std 0.02
        for weight in self.parameters():
            weight.data.normal_(0, 0.02)
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, activation='relu'):
super(SimpleRNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.output_size = output_size
        if activation not in ("tanh", "relu", "sigmoid"):
            raise ValueError("Invalid activation. Please use tanh, relu or sigmoid activation.")
        # first layer maps input_size -> hidden_size,
        # every stacked layer maps hidden_size -> hidden_size
        self.rnn_cell_list = nn.ModuleList()
        self.rnn_cell_list.append(RnnCell(self.input_size, self.hidden_size, activation))
        for l in range(1, self.num_layers):
            self.rnn_cell_list.append(RnnCell(self.hidden_size, self.hidden_size, activation))
self.fc = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input, hidden_state=None):
        '''
        Inputs: input (torch tensor) of shape [batch_size, sequence_length, input_size]
        Output: out (torch tensor) of shape [batch_size, output_size]
        '''
        if hidden_state is None:
            h0 = torch.zeros(self.num_layers, input.size(0), self.hidden_size, device=device)
        else:
            h0 = hidden_state
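        # note: from here on h0 is always a tensor, never None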
outs = []
hidden = list()
for layer in range(self.num_layers):
hidden.append(h0[layer, :, :])
for t in range(input.size(1)):
for layer in range(self.num_layers):
if layer == 0:
hidden_l = self.rnn_cell_list[layer](input[:, t, :], hidden[layer])
else:
                    hidden_l = self.rnn_cell_list[layer](hidden[layer - 1], hidden[layer])
                hidden[layer] = hidden_l
outs.append(hidden_l)
# select last time step indexed at [-1]
out = outs[-1].squeeze()
out = self.fc(out)
return out
def test():
    # batch size 64, sequence length 28*28, input size 28*28
model = SimpleRNN(input_size=28*28, hidden_size=128, num_layers=3, output_size=10)
model = model.to(device)
x = torch.randn(64, 28*28)
x = x.unsqueeze(-1)
vals = torch.ones(64, 28*28, 28*28-1) * (28*28)
x = torch.cat([x, vals], dim=-1).to(device)
out = model(x)
xshape = out.shape
return x, xshape
testx, xdims = test()
assert xdims == torch.Size([64, 10])
print("Size test: passed.")