PyTorch neural network for classification

The task is to train a neural network that, for a given input x, chooses the value v from a list that minimizes (x - v)^2. For example, I have a 100x1 tensor of integers as input. My neural network should output a probability distribution over those values for each input.
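
To make the target concrete, here is a minimal sketch of how the closest value could be computed directly, without a network. The three sample inputs are made up for illustration; only `values` comes from my code.

```python
import torch

# Candidate values from the question; the inputs below are hypothetical.
values = torch.tensor([0.0, 1.0, 3.0, 5.0, 100.0])
x = torch.tensor([[1.0], [6.0], [60.0]])  # shape (3, 1)

# Broadcasting (3, 1) against (5,) gives a (3, 5) matrix of squared distances.
sq_dist = (x - values) ** 2

# For each input row, the index and value of the closest candidate.
nearest_idx = sq_dist.argmin(dim=1)
nearest_val = values[nearest_idx]

print(nearest_idx.tolist())  # [1, 3, 4]
print(nearest_val.tolist())  # [1.0, 5.0, 100.0]
```

Ideally the network would put most of the probability mass on `nearest_idx` for each row.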

It works properly for a 1x1 tensor, but for a batch it outputs the same distribution for every example in the input. I want it to output the proper distribution for each example separately. How should I fix that?

Here is the code:

import torch
import torch.nn as nn
import torch.optim as optim

N, D_in, D_out, H = 20, 1, 5, 10
X = torch.LongTensor(100, 1).random_(0, 500).float()
values = torch.FloatTensor([0, 1, 3, 5, 100])

model = nn.Sequential(
    nn.Linear(D_in, H),
    nn.Linear(H, D_out),
    nn.Softmax(dim=1),  # assumed: the printed output rows sum to 1
)

def loss_function(x, y, probability):
    return torch.mean(torch.sum(probability*((x-y)**2), 1).div(probability.size(1)))

loss_fn = loss_function
learning_rate = 1e-3
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

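for epoch in range(200):
    pred = model(X)
    loss = loss_fn(X, values, pred)
    print('loss: ', loss.item())
    # assumed: the usual update step, which was missing from the snippet
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()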
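
As a sanity check on what the loss rewards, the sketch below evaluates the same `loss_function` on two hand-made distributions: a uniform one and a one-hot one placed on the closest value. The inputs and both distributions are hypothetical and only serve to illustrate the shapes involved.

```python
import torch
import torch.nn.functional as F

def loss_function(x, y, probability):
    # Same expression as in the question: expected squared distance per row,
    # divided by the number of candidates, then averaged over the batch.
    return torch.mean(torch.sum(probability * ((x - y) ** 2), 1).div(probability.size(1)))

values = torch.tensor([0.0, 1.0, 3.0, 5.0, 100.0])
x = torch.tensor([[1.0], [6.0], [60.0]])  # hypothetical inputs, shape (3, 1)

# Uniform distribution over the 5 candidates for every row.
uniform = torch.full((3, 5), 0.2)

# One-hot distribution on the closest candidate for every row.
nearest_idx = ((x - values) ** 2).argmin(dim=1)
one_hot = F.one_hot(nearest_idx, num_classes=5).float()

loss_uniform = loss_function(x, values, uniform)
loss_one_hot = loss_function(x, values, one_hot)
print(loss_uniform.item(), loss_one_hot.item())
```

The one-hot distribution gives the lower loss, so the loss itself does favor a different distribution for each row.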

Here are the input and the output for 20 examples:


tensor([[ 16.],
         [ 28.],
         [ 16.],
         [ 13.],
         [ 86.],
         [ 38.],


tensor([[1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
         [9.9158e-05, 1.0086e-04, 9.8492e-05, 1.0170e-04, 9.9960e-01],
         [3.3826e-05, 3.4437e-05, 3.4443e-05, 3.2122e-05, 9.9987e-01],
         [1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
         [3.3173e-05, 3.3778e-05, 3.3758e-05, 3.1419e-05, 9.9987e-01],
         [3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
         [3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
         [3.3170e-05, 3.3775e-05, 3.3754e-05, 3.1416e-05, 9.9987e-01],
         [1.4026e-04, 1.4281e-04, 1.3817e-04, 1.4524e-04, 9.9943e-01],
         [4.4578e-05, 4.5341e-05, 4.5195e-05, 4.3355e-05, 9.9982e-01],
         [8.1636e-05, 8.2994e-05, 8.1468e-05, 8.2810e-05, 9.9967e-01],
         [3.5991e-05, 3.6629e-05, 3.6643e-05, 3.4392e-05, 9.9986e-01],
         [3.4158e-05, 3.4773e-05, 3.4784e-05, 3.2472e-05, 9.9986e-01],
         [3.4881e-05, 3.5505e-05, 3.5521e-05, 3.3232e-05, 9.9986e-01],
         [3.3425e-05, 3.4031e-05, 3.4026e-05, 3.1694e-05, 9.9987e-01],
         [3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
         [3.3191e-05, 3.3796e-05, 3.3777e-05, 3.1439e-05, 9.9987e-01],
         [3.9212e-05, 3.9897e-05, 3.9870e-05, 3.7751e-05, 9.9984e-01],
         [3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
         [3.3207e-05, 3.3812e-05, 3.3795e-05, 3.1457e-05, 9.9987e-01]],