I am trying to implement the toy pointer-network example in PyTorch, but the training loss is not decreasing.
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
class Model(nn.Module):
    """Toy pointer network: LSTM encoder/decoder with additive attention whose
    attention weights over input positions ARE the output distribution.

    Shape legend: B = batch, I = input_dim (expected to be 1 here, since
    `inp[:, j:j+1]` yields BxI only for a 2-D BxJ input), J = input length,
    H = hidden_size, D = blend_dim, N = num_of_indices.
    """

    def __init__(self, input_dim, hidden_size, num_of_indices, blend_dim, batch_size):
        super(Model, self).__init__()
        self.batch_size = batch_size          # B
        self.input_dim = input_dim            # I
        self.hidden_size = hidden_size        # H
        self.num_of_indices = num_of_indices  # N: how many pointers to emit
        self.blend_dim = blend_dim            # D: attention ("blend") width
        self.encode = nn.LSTMCell(input_dim, hidden_size)
        self.decode = nn.LSTMCell(input_dim, hidden_size)
        self.blend_decoder = nn.Linear(hidden_size, blend_dim)
        self.blend_encoder = nn.Linear(hidden_size, blend_dim)
        self.scale_blend = nn.Linear(blend_dim, input_dim)

    def zero_hidden_state(self):
        """Return a BxH zero state on the model's device.

        BUG FIX: the original returned torch.randn(...) — a *different random*
        initial state on every forward pass, which injects noise and is a
        likely reason the loss never decreases.  A "zero" state is zeros.
        The device is taken from the parameters instead of a hard-coded
        .cuda(), so the model also runs on CPU.
        """
        device = next(self.parameters()).device
        return torch.zeros(self.batch_size, self.hidden_size, device=device)

    def forward(self, inp):
        """inp: BxJ (with input_dim == 1) -> NxBxJ log-distributions."""
        hidden = self.zero_hidden_state()      # BxH
        cell_state = self.zero_hidden_state()  # BxH

        # --- Encode: collect one state per input position. ---
        # NOTE(review): classic pointer networks attend over the *hidden*
        # states; this code uses the cell states.  Preserved as-is, but
        # switching to `hidden` here is worth trying if training still stalls.
        encoder_states = []
        seq_len = inp.size(1)  # J
        for j in range(seq_len):
            encoder_input = inp[:, j:j + 1]  # BxI (I == 1)
            hidden, cell_state = self.encode(encoder_input, (hidden, cell_state))
            encoder_states.append(cell_state)

        # --- Decode: emit N pointer distributions, feeding the greedily
        # selected input value back in as the next decoder input. ---
        pointer_distributions = []
        # Start token 0, BxI; new_zeros matches inp's dtype and device.
        decoder_input = inp.new_zeros(self.batch_size, self.input_dim)
        hidden = self.zero_hidden_state()
        cell_state = self.zero_hidden_state()
        for _ in range(self.num_of_indices):  # was `i`, shadowed by inner loop
            hidden, cell_state = self.decode(decoder_input, (hidden, cell_state))
            decoder_blend = self.blend_decoder(cell_state)  # BxD
            scores = []
            for enc_state in encoder_states:  # inner loop also used `i` before
                encoder_blend = self.blend_encoder(enc_state)  # BxD
                # Additive attention score for this position, squeezed to B.
                scores.append(self.scale_blend(encoder_blend + decoder_blend).squeeze(1))
            index_predistribution = torch.stack(scores, dim=1)  # BxJ
            # BUG FIX: log_softmax needs an explicit dim; the original relied
            # on the deprecated implicit-dim behaviour.
            index_distribution = F.log_softmax(index_predistribution, dim=1)
            pointer_distributions.append(index_distribution)
            # BUG FIX: `.data.max(1)[1].squeeze(1)` errors on a 1-D result;
            # argmax(dim=1) on a detached tensor is the modern equivalent.
            index = index_distribution.detach().argmax(dim=1)  # B
            # BUG FIX: `embedding_lookup` was undefined here, and the
            # lookup + diag trick built an O(B^2) BxB matrix; gather picks
            # the selected input value per row directly.
            decoder_input = inp.gather(1, index.unsqueeze(1))  # Bx1 = BxI
        return torch.stack(pointer_distributions)  # NxBxJ
Here is the full notebook:
Thanks in advance