I wrote a Seq2Seq model for conversation generation, but it is extremely slow. I saw this post and timed my model the way @apaszke suggested:
import time

torch.cuda.synchronize()  # make sure pending GPU work is finished before timing
start = time.time()
output = model(input)
torch.cuda.synchronize()  # wait for the forward pass to actually complete on the GPU
end = time.time()
print('Model time one batch:', end - start)
When I run 10 batches through the Seq2Seq with dynamic attention, the results look like this:
Model time one batch: 6.50730013847
Model time one batch: 5.17414689064
Model time one batch: 4.81271314621
Model time one batch: 4.43320679665
Model time one batch: 4.23180413246
Model time one batch: 4.5174510479
Model time one batch: 8.60860896111
Model time one batch: 4.29604315758
Model time one batch: 4.30749702454
Model time one batch: 4.46091389656
Model time one batch: 4.33084321022
Dynamic attention means I have to recompute the context vector at every decoding time step. As you can see, I use two for-loops to compute attn_weights at each step:
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable


class Attn(nn.Module):
    def __init__(self, input_size, attn_size, use_cuda=True):
        super(Attn, self).__init__()
        self.input_size = input_size
        self.attn_size = attn_size
        # renamed from `cuda` so it does not shadow nn.Module.cuda()
        self.use_cuda = use_cuda
        self.attn = nn.Linear(self.input_size * 2, attn_size)
        self.mlp = nn.Sequential(
            nn.Linear(attn_size, attn_size),
            nn.Tanh()
        )
        self.v = nn.Parameter(torch.FloatTensor(1, attn_size))
        init.xavier_uniform(self.v)

    def forward(self, state, encoder_outputs, encoder_input_lengths):
        """
        :Parameters:
            :state: decoder state at the last time step, shape=[num_layers, B, hidden_size]
            :encoder_outputs: encoder outputs of all time steps, shape=[B, T, hidden_size]
            :encoder_input_lengths: List, [B]
        :Return:
            :attn_weights: normalized attention weights, shape=[B, T]
        """
        this_batch_size = state.size(1)
        max_len = encoder_outputs.size(1)
        attn_energies = Variable(torch.zeros(this_batch_size, max_len))
        if self.use_cuda:
            attn_energies = attn_energies.cuda()
        # score every valid (batch, encoder step) pair, one element at a time
        for i in range(this_batch_size):
            for j in range(encoder_input_lengths[i]):
                attn_energies[i, j] = self.score(state[-1][i], encoder_outputs[i, j])
        # positions beyond each sequence length stay 0; mask them out of the softmax
        attn_mask = attn_energies.ne(0).float()
        attn_exp = torch.exp(attn_energies) * attn_mask
        attn_weights = attn_exp / attn_exp.sum(1, keepdim=True).expand_as(attn_exp)
        return attn_weights

    def score(self, si, hj):
        """
        :Parameters:
            :si: decoder state at time i-1
            :hj: encoder output state at time j
        """
        # v * tanh(W * concat(si, hj))
        inp = torch.cat((si, hj)).unsqueeze(0)
        energy = self.attn(inp)
        energy = self.mlp(energy)
        energy = self.v.view(-1).dot(energy.view(-1))  # scalar attention score
        return energy
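I suspect the double loop in forward() is the expensive part, since it launches one tiny kernel per (batch, time) pair. Would a batched version along these lines be faster? This is an untested sketch (shapes assumed as in the docstring; it would replace forward() and score()):

    def forward_batched(self, state, encoder_outputs, encoder_input_lengths):
        # Untested sketch: score all (batch, encoder step) pairs in one shot.
        B, T, H = encoder_outputs.size()
        # repeat the last-layer decoder state across the T encoder steps
        s = state[-1].unsqueeze(1).expand(B, T, H)
        inp = torch.cat((s, encoder_outputs), 2).view(B * T, -1)   # [B*T, 2H]
        energy = self.mlp(self.attn(inp))                          # [B*T, attn_size]
        attn_energies = energy.mm(self.v.t()).view(B, T)           # dot each row with v
        # build the padding mask from the lengths instead of relying on zero energies
        mask = torch.zeros(B, T)
        for i, length in enumerate(encoder_input_lengths):
            mask[i, :length] = 1
        mask = Variable(mask.cuda() if self.use_cuda else mask)
        attn_exp = torch.exp(attn_energies) * mask
        return attn_exp / attn_exp.sum(1, keepdim=True).expand_as(attn_exp)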
Static attention means the context vector stays the same across decoding time steps, so it only has to be computed once.
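In the decoder loop that roughly means the following (a sketch with placeholder names: attn, init_state, embedded, max_decode_len, decoder_cell, and hidden are stand-ins for my actual variables):

    # compute attention once from the initial decoder state ...
    attn_weights = attn(init_state, encoder_outputs, encoder_input_lengths)  # [B, T]
    context = attn_weights.unsqueeze(1).bmm(encoder_outputs).squeeze(1)      # [B, H]
    for t in range(max_decode_len):
        # ... and reuse the same context vector at every decoding step
        rnn_input = torch.cat((embedded[:, t, :], context), 1)
        output, hidden = decoder_cell(rnn_input, hidden)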
Note: I think the reported times cover only forward(); with backward() included it would probably be much slower. :(
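If I wanted to confirm that, I suppose the same synchronize pattern extends to the backward pass (a sketch; criterion and target are placeholders for my loss setup):

    torch.cuda.synchronize()
    start = time.time()
    output = model(input)
    loss = criterion(output, target)  # criterion/target: placeholder loss setup
    loss.backward()
    torch.cuda.synchronize()
    print('Forward+backward one batch:', time.time() - start)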
And 10 batches on the Seq2Seq with static attention:
Model time one batch: 1.55489993095
Model time one batch: 0.443991184235
Model time one batch: 0.185837030411
Model time one batch: 0.196111917496
Model time one batch: 0.193861961365
Model time one batch: 0.194068908691
Model time one batch: 0.190461874008
Model time one batch: 0.18402504921
Model time one batch: 0.186547040939
Model time one batch: 0.183899879456
Model time one batch: 0.191169023514
And 10 batches on the raw Seq2Seq (without any attention context):
Model time one batch: 1.13855099678
Model time one batch: 0.356266021729
Model time one batch: 0.185835123062
Model time one batch: 0.170114040375
Model time one batch: 0.170575141907
Model time one batch: 0.171154975891
Model time one batch: 0.17102599144
Model time one batch: 0.185311079025
Model time one batch: 0.166770935059
Model time one batch: 0.163444042206
Model time one batch: 0.169273138046
- The raw Seq2Seq averages about 0.17s per batch, which is still slower than TensorFlow. How can I find the bottleneck in my model? (One idea I had is sketched below.)
- Judging by the comparison, the dynamic attention is the time-consuming part. How can I rewrite that code to make it run faster? (Would something like the batched sketch above be the right direction?)
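For the bottleneck question, is something like the autograd profiler the right tool (assuming my PyTorch version has it)? An untested sketch:

    import torch.autograd.profiler as profiler

    # profile one forward pass per op; use_cuda also records GPU kernel times
    with profiler.profile(use_cuda=True) as prof:
        output = model(input)
    print(prof)  # per-op table of CPU/GPU times, to spot the hot spots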
Thanks!