Seq2Seq with attention extremely slow

I wrote a Seq2Seq model for conversation generation, but it is extremely slow.
I saw this post and timed my model as @apaszke suggested:

import time

torch.cuda.synchronize()
start = time.time()
output = model(input)
torch.cuda.synchronize()
end = time.time()

When I run 10 batches through the Seq2Seq with dynamic attention, the results look like this:

Model time one batch: 6.50730013847
Model time one batch: 5.17414689064
Model time one batch: 4.81271314621
Model time one batch: 4.43320679665
Model time one batch: 4.23180413246
Model time one batch: 4.5174510479
Model time one batch: 8.60860896111
Model time one batch: 4.29604315758
Model time one batch: 4.30749702454
Model time one batch: 4.46091389656
Model time one batch: 4.33084321022

Dynamic attention means I need to compute the context vector at each decoding time step. As you can see, I use two for loops to compute attn_weights at each time step (a vectorized sketch follows after the code).

import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable


class Attn(nn.Module):
    def __init__(self, input_size, attn_size, use_cuda=True):
        super(Attn, self).__init__()
        self.input_size = input_size
        self.attn_size = attn_size
        self.use_cuda = use_cuda  # renamed so it does not shadow nn.Module.cuda()

        self.attn = nn.Linear(self.input_size * 2, attn_size)
        self.mlp = nn.Sequential(
                    nn.Linear(attn_size, attn_size),
                    nn.Tanh()
                    )
        self.v = nn.Parameter(torch.FloatTensor(1, attn_size))
        init.xavier_uniform(self.v) 

    def forward(self, state, encoder_outputs, encoder_input_lengths):
        """
        :Parameters:
        :state: decoder last time step state, shape=[num_layers, B, hidden_size]
        :encoder_outputs: encoder outputs of all time steps, shape=[B, T, hidden_size]
        :encoder_input_lengths: List, [B_enc] 
        :Return: attn_weights, shape=[B, T]
        """
        this_batch_size = state.size(1)
        max_len = encoder_outputs.size(1)

        attn_energies = Variable(torch.zeros(this_batch_size, max_len))
        if self.use_cuda:
            attn_energies = attn_energies.cuda()

        # score every (decoder state, encoder output) pair one element at a time
        for i in range(this_batch_size):
            for j in range(encoder_input_lengths[i]):
                attn_energies[i, j] = self.score(state[-1][i], encoder_outputs[i, j])
        # masked softmax: padded positions were never written, so they stay at 0 and are
        # zeroed out before normalization (fragile if a real energy happens to be exactly 0)
        attn_mask = attn_energies.ne(0).float()
        attn_exp = torch.exp(attn_energies) * attn_mask
        attn_weights = attn_exp / torch.cat([attn_exp.sum(1)] * max_len, 1)
        return attn_weights
    
    def score(self, si, hj):
        """
        :Parameters:
        :si: time=i-1,decoder state
        :hj: time=j, encoder output state
        """
        # v*tanh(W*concat(si, hj))
        inp = torch.cat((si, hj)).unsqueeze(0)
        energy = self.attn(inp)
        energy = self.mlp(energy) #F.tanh(energy)
        energy = self.v.dot(energy)
        return energy
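
For reference, here is a rough vectorized sketch of the same concat-style score that computes all B x T energies in one call instead of the double loop. It is written against a newer, Variable-free PyTorch API; the function and argument names are illustrative (not part of the module above), and it masks padded positions explicitly instead of relying on zero energies:

import torch
import torch.nn.functional as F

def batched_concat_attention(state, encoder_outputs, encoder_input_lengths, attn, mlp, v):
    # state: [num_layers, B, H], encoder_outputs: [B, T, H]
    # attn: nn.Linear(2H, A), mlp: Linear(A, A) + Tanh, v: parameter of shape [1, A]
    B, T, H = encoder_outputs.size()
    s = state[-1].unsqueeze(1).expand(B, T, H)        # repeat the decoder state over time
    concat = torch.cat((s, encoder_outputs), dim=2)   # [B, T, 2H]
    energies = mlp(attn(concat)).matmul(v.view(-1))   # scores for every (i, j) pair -> [B, T]
    # mask out positions beyond each source length before the softmax
    lengths = torch.as_tensor(encoder_input_lengths, device=energies.device)
    mask = torch.arange(T, device=energies.device).unsqueeze(0) >= lengths.unsqueeze(1)
    energies = energies.masked_fill(mask, float('-inf'))
    return F.softmax(energies, dim=1)                 # [B, T]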

Static attention means the context vector is unchanged at each decoding time step, so there is no need to compute it multiple times.
Note: I think the reported times are just for forward(); with backward() it would probably be much slower. :(

And here are 10 batches on the Seq2Seq with static attention:

Model time one batch: 1.55489993095
Model time one batch: 0.443991184235
Model time one batch: 0.185837030411
Model time one batch: 0.196111917496
Model time one batch: 0.193861961365
Model time one batch: 0.194068908691
Model time one batch: 0.190461874008
Model time one batch: 0.18402504921
Model time one batch: 0.186547040939
Model time one batch: 0.183899879456
Model time one batch: 0.191169023514

And here are 10 batches on the raw Seq2Seq (without any attention context):

Model time one batch: 1.13855099678
Model time one batch: 0.356266021729
Model time one batch: 0.185835123062
Model time one batch: 0.170114040375
Model time one batch: 0.170575141907
Model time one batch: 0.171154975891
Model time one batch: 0.17102599144
Model time one batch: 0.185311079025
Model time one batch: 0.166770935059
Model time one batch: 0.163444042206
Model time one batch: 0.169273138046

  1. The raw Seq2Seq needs about 0.17 s per batch on average, which is still slower than TensorFlow. How can I find the bottleneck in my model? (See the profiling sketch below.)

  2. Judging from the comparison, dynamic attention is the time-consuming part. How can I improve this part of the code to make it run faster?

     Thanks!
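
One way to localize the bottleneck mentioned in question 1 (an illustrative sketch, not from the original post; it assumes a PyTorch build that ships torch.autograd.profiler) is to profile a single forward pass and look at the per-operator breakdown:

import torch

# `model` and `input` are the same objects as in the timing snippet above
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    output = model(input)
print(prof)  # per-operator times; a Python-level scoring loop shows up as many tiny kernel launches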


Did you manage to speed this up? I’m facing the exact same problem. :cold_sweat:

Edit 1:
I sped it up by vectorizing the operations with torch.baddbmm instead of the double for loop.

I used this fact a lot: for a batch-first scenario,

x = torch.randn(10, 4, 8) # encoder inputs (batch x words x hidden)
y = torch.randn(10, 8, 1) # decoder hidden state (batch x hidden x 1)

# this for loop version
aa = torch.zeros(10, 4)
for batch in range(x.size()[0]):
    for word in range(x.size()[1]):
        aa[batch,word] = x[batch,word].dot(y[batch,:].squeeze(1))

# is equivalent to this vector version
bb = torch.baddbmm(torch.zeros(4,1), x, y).squeeze(2)

# verify if they are equal
aa == bb
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
    # 1     1     1     1
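
Applied to attention with batch-first tensors (encoder_outputs of shape [B, T, H] and a decoder state slice of shape [B, H], as in the first post), the same trick collapses the double loop into a single batched matrix multiply. A rough sketch, assuming a plain dot-product score rather than the concat score used above:

import torch
import torch.nn.functional as F

B, T, H = 10, 4, 8
encoder_outputs = torch.randn(B, T, H)  # batch-first encoder states
dec_state = torch.randn(B, H)           # e.g. state[-1] from the first post

energies = torch.bmm(encoder_outputs, dec_state.unsqueeze(2)).squeeze(2)  # [B, T]
attn_weights = F.softmax(energies, dim=1)                                 # [B, T]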

Hi ixaxaar,

I also have exactly the same problem, so I'm trying to make it faster with baddbmm, but I'm having trouble.
In my case, encoder_outputs has shape (length x batch_size x hidden_vector_size) and hidden has shape (1 x batch_size x hidden_vector_size) in the forward function, i.e. they are not batch-first.

class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, hidden, encoder_outputs):
        max_len = encoder_outputs.size(0)
        this_batch_size = encoder_outputs.size(1)

        # Create variable to store attention energies
        attn_energies = Variable(torch.zeros(this_batch_size, max_len))  # B x S

        if USE_CUDA:
            attn_energies = attn_energies.cuda()

        # For each batch of encoder outputs
        for b in range(this_batch_size):
            # Calculate energy for each encoder output
            for i in range(max_len):
                attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))
        # my attempt to replace the for loop with baddbmm (this is the line that fails):
        bb = torch.baddbmm(Variable(torch.zeros(max_len, 1).cuda()), encoder_outputs.transpose(0,1), hidden.transpose(0,1)).squeeze(2)
        # Normalize energies to weights in range 0 to 1, resize to 1 x B x S
        return F.softmax(attn_energies).unsqueeze(1)

    def score(self, hidden, encoder_output):

        if self.method == 'dot':
            energy = hidden.dot(encoder_output)
            return energy

        elif self.method == 'general':
            energy = self.attn(encoder_output)
            energy = hidden.dot(energy)
            return energy

        elif self.method == 'concat':
            energy = self.attn(torch.cat((hidden, encoder_output), 1))
            energy = self.v.dot(energy)
            return energy

Instead of the for loop, I'm trying to use baddbmm like this:

        for b in range(this_batch_size):
            # Calculate energy for each encoder output
            for i in range(max_len):
                attn_energies[b, i] = self.score(hidden[:, b], encoder_outputs[i, b].unsqueeze(0))

bb = torch.baddbmm(Variable(torch.zeros(max_len, 1).cuda()), encoder_outputs.transpose(0,1), hidden.transpose(0,1)).squeeze(2)

But I get the following error, and it's not working:

RuntimeError: expected 3D tensor at /b/wheel/pytorch-src/torch/lib/THC/generic/THCTensorMathBlas.cu:437

If you can figure out how to apply baddbmm to my case, could you help me?

You could probably use torch.bmm instead.
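
For the shapes described above (encoder_outputs of size length x batch_size x hidden and hidden of size 1 x batch_size x hidden), a minimal bmm sketch of the 'dot' score could look like the snippet below. One likely problem with the baddbmm call is that hidden.transpose(0, 1) has shape (batch x 1 x hidden), while the second operand needs shape (batch x hidden x 1); with bmm there is also no extra 2-D zeros matrix to worry about. This is illustrative and not tested against the exact module in the question:

import torch
import torch.nn.functional as F

T, B, H = 5, 3, 16                      # illustrative sizes
encoder_outputs = torch.randn(T, B, H)  # length-first, as in the question
hidden = torch.randn(1, B, H)

# [B, T, H] @ [B, H, 1] -> [B, T, 1]
attn_energies = torch.bmm(encoder_outputs.transpose(0, 1),
                          hidden.permute(1, 2, 0)).squeeze(2)  # [B, T]
attn_weights = F.softmax(attn_energies, dim=1).unsqueeze(1)    # [B, 1, T], as in the forward() above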


Hi, I have the same problem as you. Were you able to speed it up?

The only way to speed up ops in PyTorch is to use vectorized ops.
As @cakeeatingpolarbear mentioned, you could use torch.bmm, among other batched ops, to do so.

Perhaps you'd also like to check out Tensor Comprehensions if you have very specialized ops that aren't natively available in PyTorch.

Have a look at https://github.com/kevinlu1211/PytorchLuongAttention


I am facing the same problem as well. Was anyone able to speed it up?