Efficient attention implementation for long sequences

What is the correct way to implement attention for long sequences in PyTorch?

I have the following model, called CRNN:

[Several convolutions]

(*) The input to the RNNs is not passed as the hidden state; it is instead fed directly
to the RNN, and the hidden state is set to zero at the beginning.

The goal of the model is to extract text from an image.

I would like to add an attention mechanism to the model between RNN1 and RNN2. I have tried several approaches, but none of them seems to work: the sequences are long, so the approaches found on the web (which are mainly focused on translation) either cause a big slowdown in training or make the program crash with an out-of-memory error from the graphics card.

Approach 1

First I based the code on several articles:

It is not a “copy and paste”, but the implementation idea is the same: go over the sequence in a for loop and apply attention on each iteration.

This solution is too slow because the sequences are pretty long: the sequence length is usually around 800.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim*2, dim, bias=False)
    def forward(self, x, context):
        # x: batch x dim, context: batch x seq x dim
        assert x.size(0) == context.size(0), \
            f' x: {x.size()} ctx : {context.size()} I'
        assert x.size(1) == context.size(2), \
            f' x: {x.size()} ctx : {context.size()} II'
        # dot-product score of x against every context position
        attn = F.softmax(
            context.bmm(
                x.unsqueeze(2) # bsz x dim x 1
            )                  # bsz x seq x 1
            .squeeze(2)        # bsz x seq
            , dim = 1)
        weighted_context = attn.unsqueeze(1)                         # bsz x 1 x seq
        weighted_context = weighted_context.bmm(context)             # bsz x 1 x dim
        weighted_context = weighted_context.squeeze(1)               # bsz x dim
        o = self.linear(torch.cat((x, weighted_context), 1))
        return torch.tanh(o) # F.tanh is deprecated

class AttnLSTM(nn.Module):

    def __init__(self, d_inp, d_hidden):
        super().__init__()
        #self.rnn = nn.LSTM(d_inp, d_hidden)
        self.rnn = nn.GRU(d_inp, d_hidden)
        self.d_hidden = d_hidden
        self.attn = Attention(d_hidden)

    def init_hidden(self, bsz, device=None):
        # zero initial hidden state, created on the same device as the input
        h = torch.zeros(1, bsz, self.d_hidden, device=device)
        if isinstance(self.rnn, nn.LSTM):
            # an LSTM also needs an initial cell state
            return (h, torch.zeros_like(h))
        return h # GRU

    def forward(self, xs, context):
        # xs ~ seq x batch x dim
        o = []
        hidden = self.init_hidden(xs.size(1), xs.device)
        for x in xs:
            # one time step per iteration
            res, hidden = self.rnn(x.unsqueeze(0), hidden)
            o.append(self.attn(res.squeeze(0), context))
        return torch.stack(o, 0)

class AttentionLSTMSlow(nn.Module):
    def __init__(self, nin, nh, nout):
        super().__init__()
        # the third positional argument of nn.LSTM is num_layers, so
        # bidirectional has to be passed by keyword
        self.rnn1 = nn.LSTM(nin, nh, bidirectional=True)
        self.rnn2 = AttnLSTM(nh*2, nh*2)
        self.fc2 = nn.Linear(nh*2, nout)
    def forward(self, inp):
        # omitting the hidden state makes the LSTM start from zeros
        rnn1, _ = self.rnn1(inp)
        rnn2 = self.rnn2(rnn1, context=rnn1.transpose(0, 1))
        return self.fc2(rnn2)
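
Since the attention output in AttnLSTM is never fed back into the RNN, the per-step loop could in principle be replaced by one pass over the whole sequence followed by a single batched matrix product. The following is a minimal sketch of that idea, not a drop-in replacement: the class name BatchedAttnGRU is hypothetical, it fixes the RNN to a GRU, and it assumes the bsz x seq x seq score tensor fits in memory.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedAttnGRU(nn.Module):
    """Hypothetical vectorized variant of AttnLSTM (name is an assumption)."""

    def __init__(self, d_inp, d_hidden):
        super().__init__()
        self.rnn = nn.GRU(d_inp, d_hidden)
        self.linear = nn.Linear(d_hidden * 2, d_hidden, bias=False)

    def forward(self, xs, context):
        # xs: seq x bsz x d_inp, context: bsz x seq_ctx x d_hidden
        out, _ = self.rnn(xs)                    # seq x bsz x d_hidden, zero initial state
        q = out.transpose(0, 1)                  # bsz x seq x d_hidden
        scores = q.bmm(context.transpose(1, 2))  # bsz x seq x seq_ctx
        attn = F.softmax(scores, dim=2)          # normalize over context positions
        weighted = attn.bmm(context)             # bsz x seq x d_hidden
        o = torch.tanh(self.linear(torch.cat((q, weighted), dim=2)))
        return o.transpose(0, 1)                 # seq x bsz x d_hidden
```

Under these assumptions it would be called the same way as AttnLSTM, e.g. BatchedAttnGRU(nh*2, nh*2)(rnn1, rnn1.transpose(0, 1)).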

Approach 2.1

I tried to implement attention using batched matrix operations, but it
crashed with an out-of-memory error, which I suppose is caused by autograd trying to
pass gradients through the scores variable.

class AttentionLSTM3(nn.Module):
    def __init__(self, nin, nh, nout):
        super().__init__()
        # bidirectional must be passed by keyword (third positional arg is num_layers)
        self.rnn1 = nn.LSTM(nin, nh, bidirectional=True)
        self.rnn2 = nn.LSTM(nh*2, nh, bidirectional=True)

        self.w = nn.Linear(128, 128) # hard-coded: must equal the feature size of ctx
    def attention(self, ctx, x):
        #  ctx/x ~ seq x bs x dim/dim'
        ctx = ctx.transpose(0, 1)
        x = x.transpose(0, 1)
        # ctx/x ~ bs x seq x dim/dim'
        scores = self.w(ctx).bmm(ctx.transpose(1, 2)) # bs x seq x seq
        scores = F.softmax(scores, dim=2) # normalize over the axis summed by the bmm below
        print('SC', scores.size())
        res = scores.bmm(x) # bs x seq x dim'
        print('RE', res.size())
        return res
    def forward(self, inp):
        rnn1, _ = self.rnn1(inp) # hidden state defaults to zeros
        att = self.attention(ctx=inp, x=rnn1)
        fc2, rnn2 = self._forward(1, att) # _forward is part of the omitted code
        return fc2
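
One way to sanity-check the guess that the scores tensor is responsible for the out-of-memory error is to estimate its size. The helper below is hypothetical; seq_len = 800 is the sequence length mentioned above, while batch_size = 32 is an assumed value that is not stated in the question.

```python
def scores_bytes(batch_size, seq_len, bytes_per_elem=4):
    """Bytes held by one bs x seq x seq tensor (float32 by default)."""
    return batch_size * seq_len * seq_len * bytes_per_elem


# seq_len comes from the question; batch_size is an assumption.
one_copy = scores_bytes(batch_size=32, seq_len=800)
print(f"{one_copy / 2**20:.1f} MiB per bs x seq x seq tensor")
```

Autograd may keep more than one tensor of this shape alive for the backward pass, so the figure is a lower bound per attention call rather than the full cost.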

Approach 2.2

I tried to minimize the “bottleneck”, so I implemented windowed attention, which makes the model attend to only part of the sequence, but no matter how small the window size is, it still crashes.

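class WinAttentionLSTM(nn.Module):
    def __init__(self, nin, nh, nout):
        super().__init__()
        # bidirectional must be passed by keyword (third positional arg is num_layers)
        self.rnn1 = nn.LSTM(nin, nh, bidirectional=True)
        self.rnn2 = nn.LSTM(nh*2, nh, bidirectional=True)
        s = 512 # hard-coded: must equal the feature size of ctx (nh*2)
        self.w = nn.Linear(s, s)
    def scores(self, ctx):
        # ctx ~ seq_bc x bs x dim
        ctx = ctx.transpose(0, 1)
        # ctx ~ bs x seq_bc x dim
        scores = self.w(ctx).bmm(ctx.transpose(1, 2)) # bs x seq_bc x seq_bc
        scores = F.softmax(scores, dim=2) # normalize over the axis summed by the bmm in forward
        return scores
    def forward(self, inp, win_size=40):
        rnn1, _ = self.rnn1(inp) # hidden state defaults to zeros
        weighted = []
        for xs in rnn1.split(win_size, dim=0):
            scores = self.scores(ctx=xs)
            w = scores.bmm(xs.transpose(0, 1)) # bs x seq_bc x dim
            weighted.append(w) # collect each window, otherwise torch.cat gets an empty list
        res = torch.cat(weighted, dim=1) # bs x seq x dim
        fc2, rnn2 = self._forward(1, res) # _forward is part of the omitted code
        return fc2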
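
The windowed variant can also be written without a Python loop by folding the windows into the batch axis, so that all windows go through a single bmm. This is only a sketch under stated assumptions: the function name windowed_attention is hypothetical, the sequence length is assumed to be a multiple of win_size (as with 800 and 40), and the scores are normalized over the key axis (dim=2).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def windowed_attention(x, w, win_size):
    """x: seq x bsz x dim, w: nn.Linear(dim, dim). Returns seq x bsz x dim."""
    seq, bsz, dim = x.shape
    assert seq % win_size == 0, "sketch assumes seq is a multiple of win_size"
    n_win = seq // win_size
    # Fold the windows into the batch axis: (bsz * n_win) x win_size x dim
    xw = x.transpose(0, 1).reshape(bsz * n_win, win_size, dim)
    scores = w(xw).bmm(xw.transpose(1, 2))   # (bsz * n_win) x win x win
    attn = F.softmax(scores, dim=2)          # each query attends within its own window
    out = attn.bmm(xw)                       # (bsz * n_win) x win x dim
    # Unfold back to the original layout
    return out.reshape(bsz, seq, dim).transpose(0, 1)


# Example with small, made-up sizes
x = torch.randn(8, 2, 16)                    # seq=8, bsz=2, dim=16
y = windowed_attention(x, nn.Linear(16, 16), win_size=4)
```

With this layout the largest intermediate tensor has bsz * seq * win_size elements instead of bsz * seq * seq.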


So the final questions are:

  1. How can attention be implemented efficiently? The slowdown in training caused by Approach 1 is very noticeable: 0.2 sec per 10 iterations vs 2.4 sec per 10 iterations. This means that instead of 4 hours of training I would have to train the model for nearly 2 days. Approaches 2.* do not fit in the 8GB of a GTX1070.
  2. Why does Approach 2.2 crash even if I set win_size = 2? Is my intuition about gradients correct?

Thanks in advance,

I omitted some code because I use a custom base class which I find too clumsy to include in this question. If the full code is needed, it can be found here: https://gist.github.com/Arseny-N/b448daa7f4840ba12850dafc25215333