How to avoid sending inputs one by one for a siamese LSTM?

Hi,
I have been trying to implement the siamese LSTM for sentence similarity introduced in the original paper on my own, but I am struggling to get the last hidden state for each input without using a for loop.
These are h3 and h4 in the diagram from the paper.


All the implementations I have seen (see here and there for examples) use a for loop to go through the time steps one by one to get the final hidden state. That looks correct to me, but it is very inefficient: when I train on CPU it only uses a single core (because of the for loop, as far as I understand).
I understand that I could run multiple epochs in parallel, but I would really like to solve this in the LSTM code itself by avoiding the for loop, which should allow the code to run faster.
My guess is that maybe I could do something with register_hook, but that is uncharted territory for me.
Could anyone give me a hint on how to make this fast with PyTorch?
Thanks in advance for any help.


I am bumping this one as I still can’t find a good solution…
Maybe I could transpose it?
Would appreciate any help 🙂

As long as you want a vanilla LSTM (i.e. no fancy dropout, no Graves-style intermediate gradient clipping, …), you should be able to just call LSTM on the whole sequence.
Or you could make it a joint input and only have one LSTM call, too.
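
Roughly something like this (just a sketch; the sizes are made up and the exp(-L1) scoring is how I remember the paper doing it):

import torch

lstm = torch.nn.LSTM(input_size=50, hidden_size=20)  # vanilla, single layer

# two sentences as seq x batch x feature tensors (same length here to keep it simple)
sent_a = torch.randn(7, 1, 50)
sent_b = torch.randn(7, 1, 50)

# one call per sentence over the whole sequence - no Python loop over time steps
_, (h_a, _) = lstm(sent_a)
_, (h_b, _) = lstm(sent_b)

# h_a[-1] and h_b[-1] are the final hidden states (h3 and h4 in your diagram)
similarity = torch.exp(-torch.norm(h_a[-1] - h_b[-1], p=1, dim=1))
print(similarity)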

Best regards

Thomas


Thanks @tom (again!).
Would you mind explaining to me why this would work? I am not sure I understand it…
Here is a diagram that might help explain what I mean (just in case):


source
Best,

I’m not sure I understand your doubts?
If you run a torch.nn.LSTM on seq x batch x feature tensors, it will give the same result as iterating over the sequence yourself (while keeping track of the state (h, c)), which is what the unrolled loop in your first link does. The final state can be taken either from the output or from the returned final state (in his course and the ULMFiT paper, Jeremy Howard recommends concatenating the final state, the maximum over the outputs, and the average over the outputs into one large feature vector).
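
For example (a quick sketch; the concat-pooling line is my paraphrase of the fastai recipe, not their exact code):

import torch

lstm = torch.nn.LSTM(10, 4)   # input_size=10, hidden_size=4, one layer
x = torch.randn(6, 5, 10)     # seq x batch x feature

out, (h_n, c_n) = lstm(x)     # out: seq x batch x hidden, h_n: num_layers x batch x hidden

# for a single-layer, unidirectional LSTM the last output step equals the final hidden state
assert torch.equal(out[-1], h_n[-1])

# "concat pooling": final state, max over time, mean over time
features = torch.cat([h_n[-1], out.max(dim=0)[0], out.mean(dim=0)], dim=1)
print(features.shape)         # batch x 3*hidden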

You can also supply batches to torch.nn.LSTM and use that to pass the two inputs in one call - but I'd recommend checking whether it actually improves speed. If you want to deal with sequences of different lengths in one batch, you need to use packed sequences (which is straightforward, too, but you have to look after the ordering).

Best regards

Thomas

Thanks @tom, I guess I am just always confused about how the dimensions are handled and I should reread some of the literature.

Edit:
Here is an experiment showing the equality:

import torch
from torch import nn
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, batch_size, num_layers=1):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.num_layers = num_layers

        # Define the LSTM layer
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)

    def init_hidden(self):
        # This is what we'll initialise our hidden state as
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers, self.batch_size, self.hidden_dim))

    def forward(self, input):
        # Method 1: a single LSTM call over the whole sequence
        # shape of lstm_out: [seq_len, batch_size, hidden_dim]
        # shape of hidden: (h, c), where h and c both
        # have shape (num_layers, batch_size, hidden_dim).
        hidden = self.init_hidden()
        lstm_out, hidden_m1 = self.lstm(input, hidden)
        self.method_1 = lstm_out[-1]        # output of the last time step
        self.mm_h = hidden_m1[0].squeeze()  # final hidden state

        # Method 2: iterate over the time steps one by one
        for t_j in range(len(input)):
            input_t = input[t_j].view(1, self.batch_size, -1)
            output_b, hidden = self.lstm(input_t, hidden)
        self.method_2 = output_b.squeeze()  # output of the last time step
        self.mm2_h = hidden[0].squeeze()    # final hidden state

input_size = 10
hidden_dim = 4
batch_size = 5
num_layers = 1
seq_len = 6
model = LSTM(input_size, hidden_dim, batch_size=batch_size,  num_layers=num_layers)
model.eval()
X=torch.randn(seq_len,batch_size,input_size)
model(X)
print(torch.equal(model.method_1, model.method_2))  # expect True
print(torch.equal(model.mm_h, model.mm2_h))         # expect True
print(torch.equal(model.method_1, model.mm_h))      # expect True

you could make it a joint input and only have one LSTM

I am unsure how I would do that, because in the paper the state dicts are the same, and if I join the inputs the weights are going to change, since they will learn from the first sentence.
So in the current implementation you have something like this in pseudocode:

...
def forward(sentence_1, sentence_2):
    ....
    state_dict = self.lstm.state_dict()
    out_1, h_1, c_1 = self.lstm(sentence_1,h,c)
    # reload the initial state dict so the weights are not influenced by sentence_1
    self.lstm.load_state_dict(state_dict)
    out_2, h_2, c_2 = self.lstm(sentence_2,h,c)
    ....

Am I missing anything?
Edit: actually, since I am not calling backward between the two passes, this state_dict reloading seems useless! I guess I could just join them on the feature dimension (seq x batch x feature*2) and then split the output into two. This might be a bit trickier when I unpack it as you mentioned.
PS: I agree with your remark about doing some transformation such as concatenating… For now I just wanted to have something as close to the paper as possible.

An LSTM is a sequential model, you can't parallelize it, so a for-loop implementation should be correct.
You can run 2 or more LSTMs in parallel though, with multiple threads or something.
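
Something like this, as a rough sketch (inference only, under the assumption that PyTorch ops release the GIL so the two forward passes can overlap):

import threading
import torch

lstm = torch.nn.LSTM(300, 100)
results = [None, None]

def run(i, x):
    # forward pass only; sharing the same module read-only across threads is fine
    with torch.no_grad():
        results[i] = lstm(x.unsqueeze(1))

threads = [threading.Thread(target=run, args=(i, x))
           for i, x in enumerate([torch.randn(15, 300), torch.randn(16, 300)])]
for t in threads:
    t.start()
for t in threads:
    t.join()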

I don’t think you actually need that (but you are missing a self in your forward signature definition).

No, you join them in the batch dimension, or rather use packed sequences for that:

import torch

lstm = torch.nn.LSTM(300, 100)
sentence_1 = torch.randn(15, 300)
sentence_2 = torch.randn(16, 300)

def lstm_twice_1(sentence_1, sentence_2):
    # order them for pack_sequence; which one comes first doesn't matter, as the distance is symmetric
    if len(sentence_2) > len(sentence_1):
        sentences = torch.nn.utils.rnn.pack_sequence([sentence_2, sentence_1])
    else:
        sentences = torch.nn.utils.rnn.pack_sequence([sentence_1, sentence_2])

    out, (h_n, c_n) = lstm(sentences)
    dist = torch.dist(h_n[0,0], h_n[0,1])
    return dist
    
def lstm_twice_2(sentence_1, sentence_2):
    out1, (h_n_1, c_n_1) = lstm(sentence_1.unsqueeze(1))
    out2, (h_n_2, c_n_2) = lstm(sentence_2.unsqueeze(1))
    dist = torch.dist(h_n_1[0, 0], h_n_2[0, 0])
    return dist

assert torch.allclose(lstm_twice_1(sentence_1, sentence_2), lstm_twice_2(sentence_1, sentence_2))
%timeit lstm_twice_1(sentence_1, sentence_2)
%timeit lstm_twice_2(sentence_1, sentence_2)

gives

100 loops, best of 3: 14.9 ms per loop
10 loops, best of 3: 24.9 ms per loop

This lazy IPython benchmark seems to suggest that

  • you can use the batch dimension to run two things through the same LSTM,
  • it’s faster, too.

It becomes more tricky if you want batches of sentence pairs because you need to sort and keep track of the indices. But by then the relative speedup of variant 1 vs. variant 2 might become a lot smaller, too.
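
In case it helps, a rough sketch of how that could look for a batch of sentence pairs (the sorting/unsorting bookkeeping is the part I mean; the names are made up):

import torch
from torch.nn.utils.rnn import pack_sequence

lstm = torch.nn.LSTM(300, 100)

def siamese_batch(sents_a, sents_b):
    # sents_a[i] and sents_b[i] form a pair; each sentence is a (seq_len_i x 300) tensor
    sentences = sents_a + sents_b                    # run all 2*N sentences in one LSTM call
    lengths = torch.tensor([len(s) for s in sentences])
    order = torch.argsort(lengths, descending=True)  # pack_sequence wants decreasing lengths
    packed = pack_sequence([sentences[i] for i in order])
    _, (h_n, _) = lstm(packed)
    unsort = torch.argsort(order)                    # undo the sorting
    h = h_n[-1][unsort]                              # row i corresponds to sentences[i] again
    n = len(sents_a)
    return torch.norm(h[:n] - h[n:], p=1, dim=1)     # one L1 distance per pair

sents_a = [torch.randn(n, 300) for n in (15, 9, 12)]
sents_b = [torch.randn(n, 300) for n in (16, 7, 11)]
print(siamese_batch(sents_a, sents_b))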

I’m not sure why @lugiavn thinks you can’t do that.

Best regards

Thomas


I meant that you can't parallelize the sequential steps inside an LSTM (the source code the OP cited), because it's… well, sequential. I'm sure torch.nn.LSTM is also doing a for loop inside.
But like I said, if it's 2 separate LSTM passes, you can simply run them on 2 threads or something.

Note that I didn’t comment on the batching thing. Now that I glance at it, I realize that if you are talking about running the same LSTM model, then sure, you should always process a batch of dozens of sentences at once, the same way you deal with images. Nobody runs an LSTM on 1 or 2 sentences at a time. Also note that the gain here comes from larger matrix computations being more efficient; it is likely still using 1 CPU core (unless torch uses some library to auto-parallelize matrix multiplication on the CPU).
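
(For what it's worth, you can check and control how many threads PyTorch may use for the matrix multiplications inside the LSTM - whether they actually help depends on the BLAS backend and the sizes involved.)

import torch

print(torch.get_num_threads())   # threads available for intra-op parallelism (e.g. CPU matmuls)
torch.set_num_threads(4)         # adjust explicitly if needed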

Hi,

I have some confusion regarding the different ways of implementing the siamese network.
Do we have to update the LSTM's hidden_state and cell_state while iterating during training, or is it also correct not to use them?

Thank you