Thanks @tom, I guess I am just always confused about how the dimensions are handled; I should reread some of the literature.
Edit:
Here is an experiment showing the equality:
import torch
from torch import nn
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, batch_size, num_layers=1):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers
        # Define the LSTM layer
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)

    def init_hidden(self):
        # This is what we'll initialise our hidden state as
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers, self.batch_size, self.hidden_dim))
    def forward(self, input):
        # Forward pass through the LSTM layer
        # shape of lstm_out: [seq_len, batch_size, hidden_dim]
        # shape of hidden: (a, b), where a and b both
        # have shape (num_layers, batch_size, hidden_dim).
        hidden = self.init_hidden()

        # Method 1: feed the whole sequence in one call
        lstm_out, hidden_m1 = self.lstm(input, hidden)
        self.method_1 = lstm_out[-1]
        self.mm_h = hidden_m1[0].squeeze()

        # Method 2: feed the sequence one time step at a time,
        # carrying the hidden state forward
        for t_j in range(len(input)):
            input_a = input[t_j, :].view(1, self.batch_size, -1)
            output_b, hidden = self.lstm(input_a, hidden)
        self.method_2 = output_b.squeeze()
        self.mm2_h = hidden[0].squeeze()
input_size = 10
hidden_dim = 4
batch_size = 5
num_layers = 1
seq_len = 6
model = LSTM(input_size, hidden_dim, batch_size=batch_size, num_layers=num_layers)
model.eval()
X = torch.randn(seq_len, batch_size, input_size)
model(X)
torch.equal(model.method_1, model.method_2)
torch.equal(model.mm_h, model.mm2_h)
torch.equal(model.method_1, model.mm_h)
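These three checks should all return True: for a single-layer LSTM the last time step of lstm_out is exactly the final hidden state h_n, and feeding the sequence one step at a time while carrying the hidden state forward is the same computation as feeding it in one call. If this were run on the GPU with cuDNN, exact equality might not hold because of floating-point ordering, in which case a tolerance-based check is safer:

torch.allclose(model.method_1, model.method_2)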
"you could make it a joint input and only have one LSTM"
I am unsure how I would do that, because in the paper the state dictionaries are the same, and if I join the inputs they are going to change, since they will learn from the first sentence.
So in the current implementation it looks something like this in pseudocode:
...
def forward(self, sentence_1, sentence_2):
    ....
    # save the weights before processing sentence_1
    state_dict = self.lstm.state_dict()
    out_1, (h_1, c_1) = self.lstm(sentence_1, (h, c))
    # reload the initial state dict so the weights are not influenced by sentence_1
    self.lstm.load_state_dict(state_dict)
    out_2, (h_2, c_2) = self.lstm(sentence_2, (h, c))
    ....
Am I missing anything?
Edit: actually, since I am not calling backward between the two forward passes, this reload seems useless! I guess I can just join the two sentences on the feature dimension (seq x batch x feature*2) and then split the output into two, roughly as sketched below. This might be a bit trickier when I unpack it, as you mentioned.
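Here is a minimal sketch of what I mean, assuming both sentences are already padded to the same seq_len and the hidden size is doubled so the output can be split back into two halves; shared_lstm, feature_dim and hidden_dim are placeholder names, not from the paper:

import torch
from torch import nn

# Placeholder sizes, just for the sketch
seq_len, batch_size, feature_dim, hidden_dim = 6, 5, 10, 4

# One LSTM over the joint input; input and hidden sizes are doubled
shared_lstm = nn.LSTM(2 * feature_dim, 2 * hidden_dim)

sentence_1 = torch.randn(seq_len, batch_size, feature_dim)
sentence_2 = torch.randn(seq_len, batch_size, feature_dim)

joint = torch.cat([sentence_1, sentence_2], dim=2)   # (seq, batch, feature*2)
out, (h_n, c_n) = shared_lstm(joint)                 # (seq, batch, hidden*2)
out_1, out_2 = out.chunk(2, dim=2)                   # split the output back into two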
PS: I agree with your remark about doing some transformation such as concatenating… For this I just wanted to have something as close to the paper as possible.