What is the correct way to implement attention for long sequences in PyTorch?
I have the following model, called CRNN:
[Several convolutions]
|*
V
[RNN1]
|*
V
[RNN2]
(*) The input to the RNNs is not passed as the hidden state. Instead, it is fed directly
to the RNN, and the hidden state is initialized to zero.
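Because the hidden state starts at zero, it does not have to be built by hand: when no initial state is passed, nn.GRU and nn.LSTM use all-zero states. A quick check with arbitrary sizes (this only illustrates the PyTorch default and is not part of the original model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, chosen only for illustration.
seq_len, batch, d_in, d_hidden = 5, 3, 8, 16
gru = nn.GRU(d_in, d_hidden)              # default layout: seq x batch x dim
x = torch.randn(seq_len, batch, d_in)

# An explicit all-zero initial state: (num_layers * num_directions) x batch x hidden
h0 = torch.zeros(1, batch, d_hidden)
out_explicit, _ = gru(x, h0)

# Omitting the state gives the same result, because PyTorch defaults h0 to zeros.
out_default, _ = gru(x)

print(torch.allclose(out_explicit, out_default))  # True
```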
The goal of the model is to extract text from an image.
I would like to add an attention mechanism between RNN1 and RNN2. I have tried several approaches, but none of them seems to work: the sequences are long, so the approaches I found on the web (which mainly target translation) either slow training down dramatically or make the program crash with an out-of-memory error on the GPU.
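For reference, the pipeline above can be sketched as follows, with an attention slot between RNN1 and RNN2. Every layer size, the single-channel input and the height-averaging step that turns the feature map into a sequence are assumptions for illustration; CRNNSketch is a hypothetical name, not the original model:

```python
import torch
import torch.nn as nn


class CRNNSketch(nn.Module):
    """Hypothetical CRNN layout: convs -> RNN1 -> (attention) -> RNN2 -> classifier.

    All sizes are illustrative assumptions, not the original model's.
    """

    def __init__(self, n_classes=80, nh=128, attention=None):
        super().__init__()
        # Stand-in for "several convolutions": halves height and width twice.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.rnn1 = nn.LSTM(128, nh, bidirectional=True)     # seq x bs x 2*nh
        # Attention slot; nn.Identity means "no attention".
        self.attention = attention if attention is not None else nn.Identity()
        self.rnn2 = nn.LSTM(2 * nh, nh, bidirectional=True)
        self.fc = nn.Linear(2 * nh, n_classes)

    def forward(self, images):                  # bs x 1 x H x W
        feats = self.cnn(images)                # bs x C x H/4 x W/4
        feats = feats.mean(dim=2)               # collapse height: bs x C x W/4
        seq = feats.permute(2, 0, 1)            # W/4 x bs x C (time = image width)
        out1, _ = self.rnn1(seq)                # zero initial state by default
        out1 = self.attention(out1)             # must keep seq x bs x 2*nh
        out2, _ = self.rnn2(out1)
        return self.fc(out2)                    # W/4 x bs x n_classes


model = CRNNSketch()
print(model(torch.randn(2, 1, 32, 400)).shape)  # torch.Size([100, 2, 80])
```

Any module placed in the attention slot has to map a seq x bs x 2*nh tensor to a tensor of the same shape, because RNN2 expects time-major input.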
Approach 1
First I based the code on several articles:
- https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb
- https://fehiepsi.github.io/blog/grapheme-to-phoneme/
- https://github.com/AuCson/PyTorch-Batch-Attention-Seq2seq/blob/master/attentionRNN.py
It is not a copy-and-paste, but the implementation idea is the same: iterate over the sequence in a for loop and apply attention at each step.
This solution does not work well because the sequences are long: the sequence length is usually around 800.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim*2, dim, bias=False)

    def forward(self, x, context):
        # x: batch x dim, context: batch x seq x dim
        assert x.size(0) == context.size(0), \
            f' x: {x.size()} ctx : {context.size()} I'
        assert x.size(1) == context.size(2), \
            f' x: {x.size()} ctx : {context.size()} II'
        scores = context.bmm(x.unsqueeze(2)).squeeze(2)   # bsz x seq
        attn = F.softmax(scores, dim=1)                   # normalise over seq
        weighted_context = attn.unsqueeze(1)              # bsz x 1 x seq
        weighted_context = weighted_context.bmm(context)  # bsz x 1 x dim
        weighted_context = weighted_context.squeeze(1)    # bsz x dim
        o = self.linear(torch.cat((x, weighted_context), 1))
        return torch.tanh(o)
class AttnLSTM(nn.Module):
    def __init__(self, d_inp, d_hidden):
        super().__init__()
        #self.rnn = nn.LSTM(d_inp, d_hidden)
        self.rnn = nn.GRU(d_inp, d_hidden)
        self.d_hidden = d_hidden
        self.attn = Attention(d_hidden)

    def init_hidden(self, bsz, device):
        # zero initial state, created on the same device as the input
        h = torch.zeros(1, bsz, self.d_hidden, device=device)
        if isinstance(self.rnn, nn.LSTM):
            return (h, torch.zeros_like(h))  # (hidden state, cell state)
        return h  # GRU

    def forward(self, xs, context):
        # xs ~ seq x batch x dim
        o = []
        hidden = self.init_hidden(xs.size(1), xs.device)
        for x in xs:
            res, hidden = self.rnn(x.unsqueeze(0), hidden)
            o.append(self.attn(res.squeeze(0), context))
        return torch.stack(o, 0)
class AttentionLSTMSlow(nn.Module):
    def __init__(self, nin, nh, nout):
        super().__init__()
        # the 3rd positional argument of nn.LSTM is num_layers, so pass bidirectional by name
        self.rnn1 = nn.LSTM(nin, nh, bidirectional=True)  # output: seq x bs x 2*nh
        self.rnn2 = AttnLSTM(nh*2, nh*2)
        self.fc2 = nn.Linear(nh*2, nout)
        ...

    def forward(self, inp):
        ...
        rnn1, _ = self.rnn1(inp, hidden)
        rnn2 = self.rnn2(rnn1, context=rnn1.transpose(0, 1))
        return self.fc2(rnn2)
    ...
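In AttnLSTM.forward the attention output is only collected, never fed back into the GRU, so the GRU's hidden states do not depend on the attention. Given that, the loop can be replaced by one GRU call over the whole sequence followed by a single batched attention step. A minimal sketch (BatchedAttention and FastAttnGRU are hypothetical names); it would not apply to "input-feeding" variants that feed the attention vector into the next RNN step:

```python
import torch
import torch.nn as nn


class BatchedAttention(nn.Module):
    """Same dot-product attention as Attention, applied to all time steps at once.

    queries: seq x bs x dim      (every output of the attending RNN)
    context: bs x ctx_len x dim  (e.g. rnn1.transpose(0, 1))
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim * 2, dim, bias=False)

    def forward(self, queries, context):
        q = queries.transpose(0, 1)                      # bs x seq x dim
        scores = q.bmm(context.transpose(1, 2))          # bs x seq x ctx_len
        attn = torch.softmax(scores, dim=-1)             # normalise over context positions
        weighted = attn.bmm(context)                     # bs x seq x dim
        out = torch.tanh(self.linear(torch.cat((q, weighted), dim=-1)))
        return out.transpose(0, 1)                       # seq x bs x dim


class FastAttnGRU(nn.Module):
    """Loop-free counterpart of AttnLSTM (valid because attention is not fed back)."""

    def __init__(self, d_inp, d_hidden):
        super().__init__()
        self.rnn = nn.GRU(d_inp, d_hidden)
        self.attn = BatchedAttention(d_hidden)

    def forward(self, xs, context):
        res, _ = self.rnn(xs)                            # whole sequence in one call, h0 = 0
        return self.attn(res, context)


# Compare with the step-by-step loop, reusing the same weights.
torch.manual_seed(0)
T, bs, d = 50, 4, 32
xs = torch.randn(T, bs, d)
ctx = torch.randn(bs, T, d)
fast = FastAttnGRU(d, d)

h = torch.zeros(1, bs, d)
steps = []
for x in xs:
    r, h = fast.rnn(x.unsqueeze(0), h)
    q = r.squeeze(0)                                     # bs x dim
    a = torch.softmax(ctx.bmm(q.unsqueeze(2)).squeeze(2), dim=1)
    w = a.unsqueeze(1).bmm(ctx).squeeze(1)               # bs x dim
    steps.append(torch.tanh(fast.attn.linear(torch.cat((q, w), 1))))
loop_out = torch.stack(steps, 0)

print(torch.allclose(fast(xs, ctx), loop_out, atol=1e-4))  # True
```

Both versions keep one attention distribution per time step for the backward pass, so memory use should be roughly comparable; the gain is that the GRU runs over the whole sequence in one call (which can use cuDNN on the GPU) instead of ~800 separate small calls.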
Approach 2.1
I tried to implement attention with batched BLAS matrix operations, but it
crashed with an out-of-memory error, which I suppose is caused by autograd having to
pass gradients through the scores
variable.
class AttentionLSTM3(nn.Module):
    def __init__(self, nin, nh, nout):
        super().__init__()
        # pass bidirectional by name: the 3rd positional argument of nn.LSTM is num_layers
        self.rnn1 = nn.LSTM(nin, nh, bidirectional=True)
        self.rnn2 = nn.LSTM(nh*2, nh, bidirectional=True)
        self.w = nn.Linear(128, 128)  # 128 = feature size of ctx (here: inp)

    def attention(self, ctx, x):
        # ctx/x ~ seq x bs x dim/dim'
        ctx = ctx.transpose(0, 1)
        x = x.transpose(0, 1)
        # ctx/x ~ bs x seq x dim/dim'
        scores = self.w(ctx).bmm(ctx.transpose(1, 2))  # bs x seq x seq
        scores = F.softmax(scores, dim=-1)  # each row sums to 1 over the attended positions
        print('SC', scores.size())
        res = scores.bmm(x)  # bs x seq x dim'
        print('RE', res.size())
        return res

    def forward(self, inp):
        ...
        rnn1, _ = self.rnn1(inp, hidden)  # nn.LSTM returns (output, (h_n, c_n))
        att = self.attention(ctx=inp, x=rnn1)
        fc2, rnn2 = self._forward(1, att)
        return fc2
    ...
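One common way to trade compute for memory in dense attention is to process the queries in blocks and recompute each block during the backward pass with torch.utils.checkpoint. In the notation of AttentionLSTM3.attention, q would be self.w(ctx), k would be ctx and v would be x (all batch-first). This is a sketch, assuming a PyTorch version that accepts use_reentrant in checkpoint (1.11 or later); the function names and chunk size are hypothetical:

```python
import torch
from torch.utils.checkpoint import checkpoint


def dense_attention(q, k, v):
    """Attention for one block of queries: q is bs x c x d, k is bs x n x d, v is bs x n x dv."""
    scores = q.bmm(k.transpose(1, 2))             # bs x c x n
    weights = torch.softmax(scores, dim=-1)       # each query's weights over the n keys sum to 1
    return weights.bmm(v)                         # bs x c x dv


def chunked_attention(q, k, v, chunk=128, use_checkpoint=True):
    """Same result as dense_attention(q, k, v), computed `chunk` queries at a time.

    Only one bs x chunk x n score block exists at a time in the forward pass. With
    checkpointing, autograd stores just the block inputs and recomputes the scores
    during backward, so the full bs x seq x seq tensor is never kept in memory.
    """
    outs = []
    for q_blk in q.split(chunk, dim=1):
        if use_checkpoint and torch.is_grad_enabled():
            outs.append(checkpoint(dense_attention, q_blk, k, v, use_reentrant=False))
        else:
            outs.append(dense_attention(q_blk, k, v))
    return torch.cat(outs, dim=1)                 # bs x seq x dv


torch.manual_seed(0)
q = torch.randn(2, 300, 16, requires_grad=True)
k = torch.randn(2, 300, 16)
v = torch.randn(2, 300, 16)
chunked = chunked_attention(q, k, v, chunk=64)
print(torch.allclose(chunked, dense_attention(q, k, v), atol=1e-5))  # True
chunked.sum().backward()                          # gradients flow through the recomputed blocks
print(q.grad.shape)                               # torch.Size([2, 300, 16])
```

Chunking without checkpointing mainly lowers the forward-pass peak: autograd would still keep every block's softmax output for backward, so the total saved memory would be about the same as the dense version.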
Approach 2.2
To reduce this bottleneck, I implemented windowed attention, which makes the model attend to only part of the sequence, but no matter how small the window size is, it still crashes.
class WinAttentionLSTM(nn.Module):
    def __init__(self, nin, nh, nout):
        super().__init__()
        # pass bidirectional by name: the 3rd positional argument of nn.LSTM is num_layers
        self.rnn1 = nn.LSTM(nin, nh, bidirectional=True)
        self.rnn2 = nn.LSTM(nh*2, nh, bidirectional=True)
        s = 512  # must equal the feature size of rnn1's output (2*nh)
        self.w = nn.Linear(s, s)

    def scores(self, ctx):
        # ctx ~ seq_bch x bs x dim
        ctx = ctx.transpose(0, 1)
        # ctx ~ bs x seq_bch x dim
        scores = self.w(ctx).bmm(ctx.transpose(1, 2))  # bs x seq_bch x seq_bch
        scores = F.softmax(scores, dim=-1)  # normalise over the keys in the window
        return scores

    def forward(self, inp, win_size=40):
        ...
        rnn1, _ = self.rnn1(inp, hidden)  # nn.LSTM returns (output, (h_n, c_n))
        weighted = []
        for xs in rnn1.split(win_size, dim=0):
            scores = self.scores(ctx=xs)
            w = scores.bmm(xs.transpose(0, 1))  # bs x win x dim
            weighted.append(w)
            print('OK')
        res = torch.cat(weighted, dim=1)  # bs x seq x dim (batch-first)
        fc2, rnn2 = self._forward(1, res)
        return fc2
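The window loop above can also be batched by folding the windows into the batch dimension, so that all windows go through a single bmm. This is a sketch of non-overlapping local attention under that idea; local_window_attention is a hypothetical helper, w plays the role of self.w, and the result stays time-major (seq x bs x dim) like nn.LSTM's output. The last window is zero-padded and its padded keys are masked out:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def local_window_attention(x, w, win):
    """Non-overlapping windowed self-attention, all windows in one bmm.

    x:   seq x bs x dim (time-major, like the output of nn.LSTM)
    w:   a dim -> dim projection such as nn.Linear (the role of self.w)
    win: window length; the sequence is zero-padded to a multiple of win
    Returns a seq x bs x dim tensor.
    """
    seq, bs, dim = x.shape
    pad = (-seq) % win
    xb = x.transpose(0, 1)                               # bs x seq x dim
    if pad:
        xb = F.pad(xb, (0, 0, 0, pad))                   # zero-pad the time axis at the end
    n_win = xb.size(1) // win
    blocks = xb.reshape(bs * n_win, win, dim)            # windows folded into the batch dim
    scores = w(blocks).bmm(blocks.transpose(1, 2))       # (bs*n_win) x win x win
    if pad:
        # padded key positions must get zero weight
        valid = torch.arange(n_win * win, device=x.device) < seq
        key_ok = valid.view(1, n_win, 1, win).expand(bs, n_win, 1, win)
        scores = scores.masked_fill(~key_ok.reshape(bs * n_win, 1, win), float("-inf"))
    attn = torch.softmax(scores, dim=-1)                 # normalise over keys in each window
    out = attn.bmm(blocks)                               # (bs*n_win) x win x dim
    out = out.reshape(bs, n_win * win, dim)[:, :seq]     # drop padded positions
    return out.transpose(0, 1)                           # seq x bs x dim


torch.manual_seed(0)
proj = nn.Linear(512, 512)
feats = torch.randn(800, 4, 512)                         # e.g. rnn1 output: seq x bs x 2*nh
print(local_window_attention(feats, proj, win=40).shape)  # torch.Size([800, 4, 512])
```

The score tensor here has (bs * seq / win) x win x win = bs x seq x win elements, so its memory grows linearly with the sequence length rather than quadratically.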
Questions
So the final questions are:
- How can attention be implemented efficiently? The slowdown in training caused by Approach 1 is very noticeable (0.2 s vs. 2.4 s per 10 iterations), which means that instead of 4 hours of training I would have to train the model for nearly 2 days, while Approaches 2.* do not fit into the 8 GB of a GTX 1070.
- Why does Approach 2.2 crash even when I set win_size = 2? Is my intuition about the gradients correct?
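A back-of-the-envelope check of the memory question: dense scores grow with the square of the sequence length, windowed scores only linearly. The batch size of 32 is an assumption (it is not stated above), and autograd may keep the softmax output and related buffers for the backward pass, so a small multiple of each figure per attention layer is a more realistic estimate:

```python
# Rough size of attention-score tensors in float32 (4 bytes per element).
# batch=32 is an assumed value for illustration; seq_len=800 is the typical length above.

def scores_mib(batch, seq_len, bytes_per_elem=4):
    """MiB for one dense batch x seq_len x seq_len score tensor."""
    return batch * seq_len * seq_len * bytes_per_elem / 2**20


def window_scores_mib(batch, seq_len, win, bytes_per_elem=4):
    """MiB for non-overlapping windows: seq_len / win blocks of win x win per sequence."""
    return batch * seq_len * win * bytes_per_elem / 2**20


print(scores_mib(batch=32, seq_len=800))                 # 78.125
print(window_scores_mib(batch=32, seq_len=800, win=40))  # 3.90625
print(window_scores_mib(batch=32, seq_len=800, win=2))   # 0.1953125
```

For batch sizes of this order, these numbers suggest that the attention scores alone are unlikely to exhaust 8 GB, especially with win_size = 2.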
Thanks in advance,
Arseny
PS.
I omitted some code because I use a custom base class that is too clumsy to include in this question. If the full code is needed, it can be found here: https://gist.github.com/Arseny-N/b448daa7f4840ba12850dafc25215333