I am following the NMT (neural machine translation) implementations in:
PyTorch: NLP From Scratch: Translation with a Sequence to Sequence Network and Attention — PyTorch Tutorials 1.7.1 documentation
TensorFlow: Neural machine translation with attention | TensorFlow Core
I noticed a marked difference in the decoder implementations:
In the TensorFlow implementation, the attention weights are computed from the decoder hidden state (the query) and the encoder output (the values), and the context vector is then a weighted sum of the encoder output under those attention weights.
In the PyTorch implementation, the attention weights are computed from the decoder hidden state and the embedded decoder input. The context vector (called attn_applied there) is then computed from those attention weights and the encoder output.
Can anyone explain the difference between these two approaches, or am I missing something?
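To make the contrast concrete, here is a rough NumPy sketch of just the score computations (not the tutorials' actual code; W1, W2, v and W_attn are stand-in weight matrices, and the sizes are arbitrary):

import numpy as np

batch, max_len, hidden = 4, 10, 8

hidden_state = np.random.randn(batch, hidden)           # decoder hidden state
enc_output   = np.random.randn(batch, max_len, hidden)  # encoder outputs
embedded     = np.random.randn(batch, hidden)           # embedded decoder input

# TensorFlow tutorial (additive / Bahdanau): score depends on the hidden state AND the encoder output.
W1 = np.random.randn(hidden, hidden)
W2 = np.random.randn(hidden, hidden)
v  = np.random.randn(hidden, 1)
score_tf = np.tanh(hidden_state[:, None, :] @ W1 + enc_output @ W2) @ v   # (batch, max_len, 1)

# PyTorch tutorial: score depends only on the hidden state and the embedded input,
# projected to a fixed max_len; the encoder output is not used to compute the scores.
W_attn = np.random.randn(2 * hidden, max_len)
score_pt = np.concatenate([embedded, hidden_state], axis=1) @ W_attn      # (batch, max_len)

# In both cases the context vector is then a weighted sum of the encoder output.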
In TensorFlow:
import tensorflow as tf

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query: decoder hidden state, shape (batch, hidden)
        # values: encoder output, shape (batch, max_len, hidden)
        query_with_time_axis = tf.expand_dims(query, 1)
        # Score uses both the hidden state and the encoder output (additive attention).
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
        attention_weights = tf.nn.softmax(score, axis=1)
        # Context vector is a weighted sum of the encoder output.
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights
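As a quick shape check (illustrative sizes only: batch_size=64, max_len=16, units=1024), the attention layer behaves like this:

attention_layer = BahdanauAttention(1024)
sample_hidden = tf.random.normal((64, 1024))       # decoder hidden state (query)
sample_output = tf.random.normal((64, 16, 1024))   # encoder output (values)
context_vector, attention_weights = attention_layer(sample_hidden, sample_output)
print(context_vector.shape)      # (64, 1024)
print(attention_weights.shape)   # (64, 16, 1)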
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_size):
        super(Decoder, self).__init__()
        self.batch_size = batch_size
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(dec_units)

    def call(self, x, hidden, encoder_out):
        # x: target token ids, shape (batch, 1) -> (batch, 1, embedding_dim) after embedding
        x = self.embedding(x)
        # Attention uses the previous decoder hidden state (query) and the encoder output (values).
        context_vector, attention_weights = self.attention(hidden, encoder_out)
        # Prepend the context vector to the embedded input along the feature axis.
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        # Note: the GRU is run without an explicit initial state here; the hidden state
        # only enters this step through the context vector.
        output, state = self.gru(x)
        output = tf.reshape(output, (-1, output.shape[-1]))
        x = self.fc(output)
        return x, state, attention_weights
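And a corresponding check for a single decoder step (again illustrative sizes; vocab_size=5000 and embedding_dim=256 are just placeholders):

decoder = Decoder(vocab_size=5000, embedding_dim=256, dec_units=1024, batch_size=64)
sample_hidden = tf.random.normal((64, 1024))
sample_enc_output = tf.random.normal((64, 16, 1024))
sample_x = tf.random.uniform((64, 1), maxval=5000, dtype=tf.int32)   # one target token per example
logits, dec_state, attn = decoder(sample_x, sample_hidden, sample_enc_output)
print(logits.shape)      # (64, 5000)
print(dec_state.shape)   # (64, 1024)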
In PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)
        # Attention weights come from the embedded decoder input and the hidden state,
        # not from the encoder outputs.
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        # The context vector (attn_applied) is still a weighted sum of the encoder outputs.
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
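For comparison, a single decoding step with the PyTorch class looks like this (illustrative sizes: hidden_size=256, output_size=2803, and assuming MAX_LENGTH = 10 as in the tutorial):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
decoder = AttnDecoderRNN(hidden_size=256, output_size=2803).to(device)
decoder_input = torch.tensor([[0]], device=device)        # e.g. the SOS token index
decoder_hidden = decoder.initHidden()
encoder_outputs = torch.zeros(10, 256, device=device)     # (MAX_LENGTH, hidden_size) placeholder
output, hidden, attn_weights = decoder(decoder_input, decoder_hidden, encoder_outputs)
print(output.shape)        # torch.Size([1, 2803])
print(attn_weights.shape)  # torch.Size([1, 10])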