PyTorch: Implementing Simple Attention Using Dummy Data

Hi, I am trying to implement simple/general attention in PyTorch. So far the model seems to be working, but what I am interested in doing is extracting the attention weights so that I can visualize them.

Here’s what I am doing: I create dummy sequence data in which the 5th time step is set equal to the target, so all the model needs to do is learn that the 5th time step determines the target and assign a higher attention weight to that time step.
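
To make the setup concrete, here is a rough sanity check of that claim (it relies on the get_data_recurrent function from the full code further down, so treat it as a sketch rather than part of the model):

# Rough sanity check of the dummy data described above; get_data_recurrent
# is the data-generation function defined in the full code below.
x, y = get_data_recurrent(4, time_steps=10, input_dim=10, attention_column=5)
print(x.shape, y.shape)            # (4, 10, 10) (4, 1)
print((x[:, 5, :] == y).all())     # True: every feature at step 5 equals y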

Strangely, when I plot the attention weights, the highest weight is observed at the last time step!

So I would like to know two things:

  1. Given the code, is my implementation of attention correct?
  2. If it is correct, why are the attention weights shifted to the last time step?

Here’s the entire code that I am playing with (I hope it is readable and understandable):

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

INPUT_DIMS = 10
TIME_STEPS = 10
ATTENTION_COL = 5
def get_data_recurrent(n, time_steps, input_dim, attention_column=10):
    """
    Data generation: x is purely random except that its values at attention_column equal the target y.
    In practice, the network should learn that the target = x[attention_column].
    Therefore, most of its attention should be focused on the value addressed by attention_column.
    :param n: the number of samples to retrieve.
    :param time_steps: the number of time steps of your series.
    :param input_dim: the number of dimensions of each element in the series.
    :param attention_column: the column linked to the target. Everything else is purely random.
    :return: x: model inputs, y: model targets
    """
    x = np.random.standard_normal(size=(n, time_steps, input_dim))
    y = np.random.randint(low=0, high=2, size=(n, 1))
    x[:, attention_column, :] = np.tile(y[:], (1, input_dim))
    return x, y

X_train, y_train = get_data_recurrent(300000,
                                      input_dim=INPUT_DIMS,
                                      time_steps=TIME_STEPS,
                                      attention_column=ATTENTION_COL)

class simple_lstm(nn.Module):
    def __init__(self, input_size, hidden_size, output_units):
        super(simple_lstm, self).__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            batch_first=True)
        self.dense1 = nn.Linear(hidden_size, output_units)

    def forward(self, x):
        # out: (batch, time_steps, hidden_size); hn: (1, batch, hidden_size)
        out, (hn, cn) = self.lstm(x)
        hidden_state = hn.squeeze(0)              # (batch, hidden_size)
        # dot-product score of every time step against the final hidden state
        attention_scores = torch.bmm(out,
                                     hidden_state.unsqueeze(2)).squeeze(2)
        soft_attention_weights = F.softmax(attention_scores, dim=1)
        # attention-weighted sum of the LSTM outputs: (batch, hidden_size)
        attention_output = torch.bmm(out.transpose(1, 2),
                                     soft_attention_weights.unsqueeze(2)).squeeze(2)
        out = self.dense1(attention_output)
        return out, soft_attention_weights

# Model training
torch_model = simple_lstm(input_size=INPUT_DIMS,
                          hidden_size=32,
                          output_units=1)

optimiser = torch.optim.Adam(params=torch_model.parameters())
criterion = nn.MSELoss()
torch_train = torch.utils.data.TensorDataset(
    torch.tensor(X_train, dtype=torch.float),
    torch.tensor(y_train, dtype=torch.float))
torch_train_loader = torch.utils.data.DataLoader(torch_train, batch_size=64)

num_epochs = 1
saw = []   # batch-averaged attention weights, recorded every step
for epoch in range(num_epochs):
    for i, (X_tr, y_tr) in enumerate(torch_train_loader):
        optimiser.zero_grad()
        output, att = torch_model(X_tr)
        saw.append(att.detach().numpy().mean(axis=0))
        loss = criterion(output, y_tr)
        loss.backward()
        optimiser.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1,
                                               num_epochs, loss.item()))

# Plot attention weights on fresh test data
att = []
with torch.no_grad():
    for i in range(300):
        X_test, y_test = get_data_recurrent(1, input_dim=INPUT_DIMS,
                                            time_steps=TIME_STEPS,
                                            attention_column=ATTENTION_COL)
        preds, attention = torch_model(torch.tensor(X_test, dtype=torch.float))
        att.append(attention.numpy())
arr = np.mean(np.array(att), axis=0)
pd.DataFrame(arr.squeeze(0), columns=['attention (%)']).plot(kind='bar',
                                                             title='Attention Mechanism as '
                                                                   'a function of input')
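
For reference, here is a rough sketch of the extra diagnostics I look at after the plot above; it only reuses names that already exist in the code (the averaged test attention arr, the saw list from the training loop, and TIME_STEPS), so it is a sketch rather than part of the model:

# Sketch of extra diagnostics (reusing `arr`, `saw` and TIME_STEPS from above):
# 1) which time step receives the highest averaged attention weight at test time
print('time step with highest test attention:', int(np.argmax(arr.squeeze(0))))

# 2) how the batch-averaged attention evolved during training
saw_history = np.array(saw)        # shape: (num_batches, TIME_STEPS)
pd.DataFrame(saw_history,
             columns=['t{}'.format(t) for t in range(TIME_STEPS)]).plot(
    title='Batch-averaged attention weight per time step over training')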

Here are the attention weights from the model:

[bar plot of the averaged attention weights: the highest weight sits on the last time step, not the 5th]

I was expecting to see a higher weight for the 5th time step, not for the last one.
Could someone please guide me on this?