How to decide the input and hidden layer dimensions for torch.nn.RNN?

I am trying to write binary addition code. I have to provide two bits at a time, so the input shape should be (1, 2), and I am taking a hidden layer size of 16:

rnn = nn.RNN(2, 16, 1)
input = torch.randn(1, 2)
h0 = torch.randn(2, 16)
output, hn = rnn(input, h0)  # raises: input must have 3 dimensions, got 2

I am getting the error input must have 3 dimensions, got 2.


In the default setup your input should have the shape [seq_len, batch_size, features].
If you want to provide the two bits sequentially, you should pass the input as [2, 1, 1] (with input_size=1).
Alternatively, you can specify batch_first=True to pass the input as [batch_size, seq_len, features].
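For example, here is a minimal sketch of both options applied to your snippet (note that h0 also needs three dimensions, [num_layers, batch_size, hidden_size]):

import torch
import torch.nn as nn

# two bits per timestep (input_size=2, seq_len=1):
rnn = nn.RNN(2, 16, 1)
x = torch.randn(1, 1, 2)    # [seq_len, batch_size, features]
h0 = torch.randn(1, 1, 16)  # [num_layers, batch_size, hidden_size]
output, hn = rnn(x, h0)

# or one bit per timestep (input_size=1, seq_len=2):
rnn = nn.RNN(1, 16, 1)
x = torch.randn(2, 1, 1)    # [seq_len=2, batch_size=1, features=1]
h0 = torch.randn(1, 1, 16)
output, hn = rnn(x, h0)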


I got it, but how can I decide the shape of the hidden dimension? I am confused.
Suppose I have a 2D input with shape (1, 2), hidden size 16, and an output with shape (1, 1); then I can say the hidden shape is (2, 16). But for the 3D case, if the input shape is (1, 1, 2) ([batch_size, seq_len, features]) and the output is (1, 1, 1), how can I calculate the shape of the hidden layer?

You can choose the hidden size as you wish.
The output will have the shape [seq_len, batch_size, hidden_size].
Here is a small example:

import torch
import torch.nn as nn

seq_len = 2
features = 1
batch_size = 5
hidden_size = 10
num_layers = 1

model = nn.RNN(
    input_size=features,
    hidden_size=hidden_size,
    num_layers=num_layers)

x = torch.randn(seq_len, batch_size, features)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
output, hn = model(x, h0)
print(output.shape)  # torch.Size([2, 5, 10]) -> [seq_len, batch_size, hidden_size]
print(hn.shape)      # torch.Size([1, 5, 10]) -> [num_layers, batch_size, hidden_size]

If you want to transform the output to a target space, you can apply another linear layer as described in this tutorial. This would create a model as described in CS231n - Slide 36.
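For illustration, a minimal sketch of that idea, reusing the sizes from the example above (the name head is just illustrative):

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=1, hidden_size=10, num_layers=1)
head = nn.Linear(10, 1)  # maps hidden_size to the target space

x = torch.randn(2, 5, 1)  # [seq_len, batch_size, features]
out, hn = rnn(x)          # out: [2, 5, 10]
y = head(out)             # y: [2, 5, 1], one prediction per timestep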


I used the following class to define the RNN:

import torch
from torch.autograd import Variable


class RNN_Model(torch.nn.Module):

    def __init__(self, input_size, rnn_hidden_size, output_size):
        super(RNN_Model, self).__init__()
        self.rnn = torch.nn.RNN(input_size, rnn_hidden_size,
                                num_layers=1, nonlinearity='relu',
                                batch_first=True)
        # hidden state is created once here and reused for every forward pass
        self.h_0 = self.initialize_hidden(rnn_hidden_size)
        self.linear = torch.nn.Linear(rnn_hidden_size, output_size)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = x.unsqueeze(0)  # add the batch dimension: [1, seq_len, features]
        self.rnn.flatten_parameters()
        out, self.h_0 = self.rnn(x, self.h_0)
        out = self.linear(out)
        out = self.sigmoid(out)
        return out

    def initialize_hidden(self, rnn_hidden_size):
        # n_layers * n_directions, batch_size, rnn_hidden_size
        return Variable(torch.randn(1, 1, rnn_hidden_size),
                        requires_grad=True)

and here is how I am training:

# training logic
def train(model, x, y, criterion, optimizer):
    '''
    Run a single optimization step and return (loss, rounded prediction).
    '''
    model.train()
    y_pred = model(x)
    loss = criterion(y_pred[0], y)
    optimizer.zero_grad()
    # retain_graph=True is needed because the hidden state created in
    # __init__ is reused across iterations
    loss.backward(retain_graph=True)
    optimizer.step()
    out = 1 if y_pred.item() >= 0.5 else 0
    return loss.item(), out
import numpy as np
import torch

# int2binary, binary_dim and largest_number are defined earlier
# (a lookup table from integers to their binary encodings)
dtype = torch.float
device = torch.device("cpu")
input_dim = 2
hidden_dim = 16
output_dim = 1
layer_dim = 1
model_one = RNN_Model(input_dim, hidden_dim, output_dim).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model_one.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
loss_list = []
for j in range(50):
    # generate a simple addition problem (a + b = c)
    a_int = np.random.randint(largest_number // 2)  # int version
    a = int2binary[a_int]                           # binary encoding
    b_int = np.random.randint(largest_number // 2)  # int version
    b = int2binary[b_int]                           # binary encoding
    # true answer
    c_int = a_int + b_int
    c = int2binary[c_int]
    # where we'll store our best guess (binary encoded)
    d = np.zeros_like(c)
    output = ''
    for position in range(binary_dim):
        # feed one bit pair per step, starting from the least significant bit
        X = np.array([[a[binary_dim - position - 1], b[binary_dim - position - 1]]])
        y = np.array([[c[binary_dim - position - 1]]]).T
        x_train = torch.tensor(X, device=device, dtype=dtype)
        y_train = torch.tensor(y, device=device, dtype=dtype)
        loss, pred_bit = train(model_one, x_train, y_train, criterion, optimizer)
        output += str(pred_bit)  # predicted bits, least significant bit first
        loss_list.append(loss)
    # convert the predicted bit string (LSB first) back to an integer
    out = 0
    out_str = ''
    for index, x in enumerate(output):
        out_str = x + out_str
        out += int(x) * pow(2, index)

My loss values are oscillating, and I am not able to understand why.
[figure: out_rnn — plot of the oscillating training loss]

Can you please tell me what is wrong in my implementation? I added a sigmoid for the output layer as well.

Based on your code it looks like you would like to learn the addition of two numbers in binary representation by passing one bit pair at a time. Is this correct?
Currently you are never detaching or re-creating the hidden state; it is only created once when the model is initialized. Could you try to create a new hidden state for each new addition?
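Something along these lines might work (a minimal, self-contained sketch with random bits instead of your int2binary setup; the name head is just illustrative):

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=2, hidden_size=16, batch_first=True)
head = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
optimizer = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=1e-2)
criterion = nn.MSELoss()

for step in range(5):
    h = torch.zeros(1, 1, 16)                   # fresh hidden state per addition
    for position in range(8):                   # one bit pair per timestep
        x = torch.randint(0, 2, (1, 1, 2)).float()
        y = torch.randint(0, 2, (1, 1, 1)).float()
        out, h = rnn(x, h)
        loss = criterion(head(out), y)
        optimizer.zero_grad()
        loss.backward()                          # no retain_graph needed now
        optimizer.step()
        h = h.detach()                           # cut the graph before the next bit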
