Pointer network example in PyTorch

I am trying to implement the toy pointer-network example in PyTorch, but the loss is not decreasing during training.

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
class Model(nn.Module):
  def __init__(self, input_dim, hidden_size, num_of_indices, blend_dim, batch_size):
    super(Model, self).__init__()
    
    self.batch_size = batch_size               # B
    self.input_dim = input_dim                 # I
    self.hidden_size = hidden_size             # H
    self.num_of_indices = num_of_indices       # N
    self.blend_dim = blend_dim                 # D
            
    self.encode = nn.LSTMCell(input_dim, hidden_size)
    self.decode = nn.LSTMCell(input_dim, hidden_size)
    self.blend_decoder = nn.Linear(hidden_size, blend_dim)
    self.blend_encoder = nn.Linear(hidden_size, blend_dim)
    self.scale_blend = nn.Linear(blend_dim, input_dim)
    
  def zero_hidden_state(self):
    # NOTE: despite the name, this returns randomly initialized states, not zeros
    return Variable(torch.randn([self.batch_size, self.hidden_size]).cuda())
    
  def forward(self, inp):
    hidden = self.zero_hidden_state()                                            # BxH
    cell_state = self.zero_hidden_state()                                        # BxH
    encoder_states = []
    for j in range(len(inp[0])):                                          # inp -> BxJxI
        encoder_input = inp[:, j:j+1]                                            # BxI
        hidden, cell_state = self.encode(encoder_input, (hidden, cell_state)) 
        encoder_states.append(cell_state)
        
    decoder_state = encoder_states[-1]                       # BxH
    pointers = []
    pointer_distributions = []
    
    start_token = 0
    decoder_input = Variable(torch.Tensor([start_token] * self.batch_size)        # BxI
                             .view(self.batch_size, self.input_dim).cuda())

    hidden = self.zero_hidden_state()                                         # BxH
    cell_state = self.zero_hidden_state()                                     # BxH
    for i in range(self.num_of_indices):
        hidden, cell_state = self.decode(decoder_input, (hidden, cell_state))     # BxH
        
        decoder_blend = self.blend_decoder(cell_state)                            # BxD
        encoder_blends = []
        index_predists = []
        for j in range(len(inp[0])):                                    # attend over all J encoder steps
            encoder_blend = self.blend_encoder(encoder_states[j])                  # BxD
            raw_blend = encoder_blend + decoder_blend                              # BxD
            scaled_blend = self.scale_blend(raw_blend).squeeze(1)                  # BxI
            
            index_predist = scaled_blend
            
            encoder_blends.append(encoder_blend)
            index_predists.append(index_predist)
            
        index_predistribution = torch.stack(index_predists).t()                    # BxJ
        index_distribution = F.log_softmax(index_predistribution)
        pointer_distributions.append(index_distribution)                          
        index = index_distribution.data.max(1)[1].squeeze(1)                       # B

        # embedding_lookup is a helper defined elsewhere in the notebook
        emb = embedding_lookup(inp.t(), Variable(index))                           # BxB
        pointer_raw = torch.diag(emb)                                              # B
        pointer = pointer_raw
        pointers.append(pointer)
        decoder_input = pointer.unsqueeze(1)                                       # Bx1

        #print('pointer: {}'.format(pointers))
    index_distributions = torch.stack(pointer_distributions)                    
    return index_distributions                                                     # NxBxJ

Here is the full notebook
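
The train function is roughly along these lines (a simplified sketch; the optimizer choice and the batch format shown here are assumptions, the exact code is in the notebook):

import torch
from torch import nn, optim

def train(epochs, model, train_batches):
    # NLLLoss pairs with the log_softmax the model returns
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters())      # optimizer choice is an assumption
    for epoch in range(epochs):
        total_loss = 0.0
        for inp, target in train_batches:           # inp: BxJxI, target: BxN index Variables
            optimizer.zero_grad()
            log_dists = model(inp)                  # NxBxJ
            loss = 0
            for step in range(log_dists.size(0)):   # sum the NLL over the N pointer steps
                loss = loss + criterion(log_dists[step], target[:, step])
            loss.backward()
            optimizer.step()
            total_loss += loss.data[0]
        if epoch % 100 == 0:
            print('epoch: {} -- loss: {}'.format(epoch, total_loss / len(train_batches)))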

Thanks in advance

Do you have a specific question, or are you looking for general suggestions? (It usually helps to ask a specific question.)

I had included the zero initialization for the decoder inside the loop, re-zeroing it at every step. I have corrected that. Training is now happening, but at an extremely slow pace. A specific question would be: where should I look to understand why training is so slow? Any general suggestions would also be helpful.
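
For clarity, the decoder state is now zeroed once before the step loop rather than inside it, i.e.:

hidden = self.zero_hidden_state()                                          # BxH, once
cell_state = self.zero_hidden_state()                                      # BxH, once
for i in range(self.num_of_indices):
    hidden, cell_state = self.decode(decoder_input, (hidden, cell_state))
    # ... attention over encoder_states as in forward() above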

train(4000, model, train_batches)

epoch: 0 -- loss: 4.162579536437988
epoch: 100 -- loss: 4.112975120544434
epoch: 200 -- loss: 4.112966060638428
epoch: 300 -- loss: 4.112963676452637
epoch: 400 -- loss: 4.1129631996154785
epoch: 500 -- loss: 4.1129631996154785
epoch: 600 -- loss: 4.11296272277832
epoch: 700 -- loss: 4.112962245941162
epoch: 800 -- loss: 4.112962245941162
epoch: 900 -- loss: 4.112961769104004
epoch: 1000 -- loss: 4.112961769104004
epoch: 1100 -- loss: 4.112961292266846
epoch: 1200 -- loss: 4.112961292266846
epoch: 1300 -- loss: 4.112961292266846
epoch: 1400 -- loss: 4.112961292266846
epoch: 1500 -- loss: 4.1129608154296875
epoch: 1600 -- loss: 4.1129608154296875
epoch: 1700 -- loss: 4.1129608154296875
epoch: 1800 -- loss: 4.1129608154296875
epoch: 1900 -- loss: 4.1129608154296875
epoch: 2000 -- loss: 4.112960338592529
epoch: 2100 -- loss: 4.112960338592529
epoch: 2200 -- loss: 4.112960338592529
epoch: 2300 -- loss: 4.112959861755371
epoch: 2400 -- loss: 4.112959861755371
epoch: 2500 -- loss: 4.112959384918213
epoch: 2600 -- loss: 4.112959384918213
epoch: 2700 -- loss: 4.112958908081055
epoch: 2800 -- loss: 4.112958908081055
epoch: 2900 -- loss: 4.1129584312438965
epoch: 3000 -- loss: 4.1129584312438965
epoch: 3100 -- loss: 4.112957954406738
epoch: 3200 -- loss: 4.112957954406738
epoch: 3300 -- loss: 4.112957954406738
epoch: 3400 -- loss: 4.112957954406738
epoch: 3500 -- loss: 4.11295747756958
epoch: 3600 -- loss: 4.11295747756958
epoch: 3700 -- loss: 4.11295747756958
epoch: 3800 -- loss: 4.112957000732422
epoch: 3900 -- loss: 4.112957000732422

You can fire up a Python profiler and see where the bottlenecks are.

If you are also using CUDA, export this environment variable before running Python so that the profiler's timings are correct:

export CUDA_LAUNCH_BLOCKING=1
python my_program.py
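
For example, with Python's built-in cProfile (reusing the train call from above):

import cProfile, pstats

# profile a shorter run and print the 20 most expensive calls by cumulative time
cProfile.run('train(100, model, train_batches)', 'pointer_train.prof')
pstats.Stats('pointer_train.prof').sort_stats('cumulative').print_stats(20)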

Thanks, I will try this out. That is about performance bottlenecks, right? Will it also help me understand why the loss per epoch is not coming down?

Is there a way to print the computation graph, to see how the gradients propagate back?

I think your problem is on this line:

raw_blend = encoder_blend + decoder_blend # BxD

You forgot the activation function, so you end up learning only a linear mapping. It should be:

raw_blend = torch.nn.functional.elu(encoder_blend + decoder_blend)

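Plugged into the inner attention loop of your forward(), that looks roughly like this (a minimal sketch; the original pointer-network paper uses tanh in this spot, and either nonlinearity works to break the pure linearity):

for j in range(len(inp[0])):
    encoder_blend = self.blend_encoder(encoder_states[j])              # BxD
    raw_blend = F.elu(encoder_blend + decoder_blend)                   # BxD, nonlinearity added here
    scaled_blend = self.scale_blend(raw_blend).squeeze(1)              # BxI
    index_predists.append(scaled_blend)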