Torch.max returns wrong indices

Hi Team,
Below is an XOR neural network written in PyTorch. The indices returned by max()[1] look wrong to me. Please see my output below as a proof of concept.

import torch as th
from torch.autograd import Variable

epochs = 2000
lr = 1
XOR_X = [[0, 0], [0, 1], [1, 0], [1, 1]]
XOR_Y = [[0, 1], [1, 0], [1, 0], [0, 1]]

x_ = Variable(th.FloatTensor(XOR_X), requires_grad=False)
y_ = Variable(th.FloatTensor(XOR_Y), requires_grad=False)

w1 = Variable(th.randn(2, 3), requires_grad=True)
w2 = Variable(th.randn(3, 2), requires_grad=True)

b1 = Variable(th.zeros(3), requires_grad=True)
b2 = Variable(th.zeros(2), requires_grad=True)


def forward(x):
    a2 = x.mm(w1)
    # PyTorch didn't have NumPy-like broadcasting when I wrote this script;
    # expand_as expands the tensor to the same size as the other tensor
    a2 = a2.add(b1.expand_as(a2))
    h2 = a2.sigmoid()
    a3 = h2.mm(w2)
    a3 = a3.add(b2.expand_as(a3))
    hyp = a3.sigmoid()
    return hyp


for epoch in range(epochs):
    hyp = forward(x_)
    cost = y_ - hyp
    cost = cost.pow(2).sum()
    if epoch % 500 == 0:
        print(cost.item())  # cost is a 0-dim tensor; .item() returns a Python float
    cost.backward()
    w1.data -= lr * w1.grad.data
    w2.data -= lr * w2.grad.data
    b1.data -= lr * b1.grad.data
    b2.data -= lr * b2.grad.data
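    # reset all gradients, biases included, so they don't accumulate across epochs
    w1.grad.data.zero_()
    w2.grad.data.zero_()
    b1.grad.data.zero_()
    b2.grad.data.zero_()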

for x in XOR_X:
    hyp = forward(Variable(th.FloatTensor([x])))
    values, indices = hyp.max(0)
    print('==========================\nX is: ', x)
    print('==========================\n hyp is: ', hyp)
    print('==========================\n indices from argmax: ', indices)

Output:

==========================
X is: [0, 0]

hyp is: Variable containing:
0.0166 0.9810
[torch.FloatTensor of size 1x2]

==========================
indices from argmax: Variable containing:
0 0
[torch.LongTensor of size 1x2]
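The comment in forward() explains that expand_as was used because PyTorch lacked NumPy-style broadcasting when the script was written. Current PyTorch versions broadcast automatically, so a plain a2 + b1 should be equivalent. A minimal sketch to check this, assuming a recent PyTorch; the random weights and bias are only for illustration and are not the trained values:

```python
import torch as th

th.manual_seed(0)  # fixed seed so the example is reproducible

# Shapes from the question: 4 inputs with 2 features, a 2x3 weight matrix and a
# length-3 bias (random values here, for illustration only).
x = th.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
w1 = th.randn(2, 3)
b1 = th.randn(3)

a2 = x.mm(w1)  # shape (4, 3)

# Style used in the question: expand the bias to a2's shape before adding.
old_style = a2.add(b1.expand_as(a2))

# Broadcasting: the (3,) bias is treated as (1, 3) and reused for all 4 rows.
new_style = a2 + b1

print(old_style.shape)  # torch.Size([4, 3])
print(th.equal(old_style, new_style))  # True
```

Both forms add b1[j] to every entry of column j, so the broadcast version simply drops the explicit expand_as call.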


I ran into the same issue with the max indices: they consistently came back as all zeros.
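The all-zeros result can be reproduced without training anything. This minimal sketch assumes a recent PyTorch, where max(dim) returns a (values, indices) pair and drops the reduced dimension by default; it hard-codes the 1x2 hyp printed in the question:

```python
import torch as th

# hyp as printed in the question: one sample (row) with two class scores.
hyp = th.tensor([[0.0166, 0.9810]])
print(hyp.shape)  # torch.Size([1, 2])

# max over dim 0 compares entries down each column. Each column holds a single
# entry, so the index of the maximum is always 0.
values0, indices0 = hyp.max(0)
print(indices0)  # tensor([0, 0])
```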

The issue is that your tensor has size 1x2 and you are taking the max over dimension 0, which has only one element, so every index it returns is 0. Take the max over dimension 1 instead, i.e. hyp.max(1).
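A minimal sketch of the fix, again assuming a recent PyTorch and reusing the 1x2 hyp from the question. It also decodes the question's one-hot XOR_Y targets with the same dim=1 reduction, to show what the predicted indices should be compared against:

```python
import torch as th

hyp = th.tensor([[0.0166, 0.9810]])  # the 1x2 output from the question

# max over dim 1 compares the two class scores within the row.
values1, indices1 = hyp.max(1)
print(indices1)  # tensor([1])

# argmax returns the same indices without the values.
print(hyp.argmax(dim=1))  # tensor([1])

# The one-hot targets from the question decode the same way, one class per row.
XOR_Y = [[0, 1], [1, 0], [1, 0], [0, 1]]
target_classes = th.tensor(XOR_Y, dtype=th.float32).argmax(dim=1)
print(target_classes)  # tensor([1, 0, 0, 1])
```

With the network from the question, forward(x_).argmax(dim=1) would give one predicted class per row for all four inputs at once, which can then be compared element-wise with target_classes instead of looping over XOR_X.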


I am so dumb :frowning:
Thanks, @fmassa!