Different behaviour in Numpy and Pytorch

Hi,
Please help me understand why my PyTorch cost is not converging while my NumPy version does (using the same logic). I am trying to create a FizzBuzz program similar to Joel’s TensorFlow implementation.
I have a 2-layer dense network with sigmoid activations in both layers. I am using MSE for the cost and the same hyperparameters for both the NumPy and the PyTorch script. My NumPy script converges to the global minimum in fewer than 1k epochs, while PyTorch is still jumping around even after 5k.

Pytorch implementation:

import numpy as np
import torch as th
from torch.autograd import Variable


# Hyper-parameters shared by the data encoding and the training loop.
input_size = 10  # bits per encoded number; 2 ** 10 = 1024 covers 0..999
epochs = 1000  # passes over the training split
batches = 64  # rows per mini-batch (a batch *size*, despite the name)
lr = 1e-2  # SGD step size


def binary_enc(num, width=None):
    """Encode a non-negative integer as a fixed-width big-endian bit list.

    Args:
        num: integer to encode; must satisfy 0 <= num < 2 ** width.
        width: number of bits in the result. Defaults to the module-level
            ``input_size`` so existing one-argument calls are unchanged.

    Returns:
        A list of ``width`` ints (0 or 1), most significant bit first.

    Raises:
        ValueError: if ``num`` is negative or does not fit in ``width`` bits,
            instead of silently returning a list of the wrong length.
    """
    if width is None:
        width = input_size
    if not 0 <= num < 2 ** width:
        raise ValueError(
            'cannot encode {} in {} bits'.format(num, width))
    # '0{width}b' zero-pads the binary representation to exactly width digits.
    return [int(bit) for bit in '{0:0{1}b}'.format(num, width)]


def binary_dec(array):
    """Decode a big-endian sequence of bits back into an integer.

    Each element is passed through int(), so float or NumPy bit values are
    accepted. An empty sequence decodes to 0.
    """
    digits = [int(bit) for bit in array]
    return sum(digit * 2 ** power
               for power, digit in enumerate(reversed(digits)))


def training_test_gen(x, y, train_frac=0.9):
    """Shuffle paired samples and split them into training and test sets.

    Args:
        x: array of input rows; indexed with an integer index array.
        y: array of target rows, paired with ``x`` by position.
        train_frac: fraction of rows (rounded down) that go to the training
            split; the remainder forms the test split.

    Returns:
        Tuple ``(trX, trY, teX, teY)``.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """
    if len(x) != len(y):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError(
            'x and y must have the same length ({} != {})'.format(
                len(x), len(y)))
    # permutation(n) shuffles arange(n) exactly like permutation(range(n)),
    # so a given seed still yields the same split.
    indices = np.random.permutation(len(x))
    split_size = int(train_frac * len(indices))
    train_idx, test_idx = indices[:split_size], indices[split_size:]
    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]


def x_y_gen():
    """Build the FizzBuzz dataset for 0..999 and split it into train/test.

    Each number becomes a binary input row (via binary_enc) and a one-hot
    target ordered [fizbuz, buz, fiz, number].

    Returns:
        The (trX, trY, teX, teY) tuple produced by training_test_gen.
    """
    def one_hot(n):
        # Multiples of 15 are tested first so they are not labelled buz/fiz.
        if n % 15 == 0:
            return [1, 0, 0, 0]
        if n % 5 == 0:
            return [0, 1, 0, 0]
        if n % 3 == 0:
            return [0, 0, 1, 0]
        return [0, 0, 0, 1]

    numbers = range(1000)
    inputs = np.array([binary_enc(n) for n in numbers])
    targets = np.array([one_hot(n) for n in numbers])
    return training_test_gen(inputs, targets)


def check_fizbuz(i):
    """Return the FizzBuzz category name for integer ``i``.

    Divisibility by 15 is checked first so multiples of both 3 and 5 map to
    'fizbuz' rather than 'buz' or 'fiz'.
    """
    for divisor, word in ((15, 'fizbuz'), (5, 'buz'), (3, 'fiz')):
        if i % divisor == 0:
            return word
    return 'number'


# Split the data and wrap it in (pre-0.4 style) autograd Variables.
trX, trY, teX, teY = x_y_gen()
if th.cuda.is_available():
    dtype = th.cuda.FloatTensor
else:
    dtype = th.FloatTensor
x = Variable(th.from_numpy(trX).type(dtype), requires_grad=False)
y = Variable(th.from_numpy(trY).type(dtype), requires_grad=False)

# NOTE(review): the weights come from th.randn, which is not seeded in this
# script and is a different RNG from np.random, so they cannot match the
# NumPy script's initial weights.
w1 = Variable(th.randn(10, 100).type(dtype), requires_grad=True)
w2 = Variable(th.randn(100, 4).type(dtype), requires_grad=True)

b1 = Variable(th.zeros(1, 100).type(dtype), requires_grad=True)
b2 = Variable(th.zeros(1, 4).type(dtype), requires_grad=True)

# Number of full mini-batches; a trailing partial batch is skipped.
no_of_batches = int(len(trX) / batches)
for epoch in range(epochs):
    for batch in range(no_of_batches):
        start = batch * batches
        end = start + batches
        x_ = x[start:end]
        y_ = y[start:end]

        # Hidden layer: sigmoid(x_ @ w1 + b1).
        a2 = x_.mm(w1)
        a2 = a2.add(b1.expand_as(a2))
        h2 = a2.sigmoid()

        # Output layer: sigmoid(h2 @ w2 + b2).
        a3 = h2.mm(w2)
        a3 = a3.add(b2.expand_as(a3))
        hyp = a3.sigmoid()

        error = hyp - y_
        # NOTE(review): d(sum(error**2))/d(error) = 2*error, twice what the
        # NumPy script back-propagates; sum()/2 would match it.
        loss = error.pow(2).sum()
        loss.backward()

        # Manual SGD step on the raw tensors.
        w1.data -= lr * w1.grad.data
        w2.data -= lr * w2.grad.data
        b1.data -= lr * b1.grad.data
        b2.data -= lr * b2.grad.data
        # BUG: only the weight gradients are reset. Autograd accumulates into
        # .grad, so b1.grad and b2.grad keep growing across batches; they also
        # need .zero_().
        w1.grad.data.zero_()
        w2.grad.data.zero_()
    # NOTE(review): prints the mean *signed* error of the last batch, not the
    # MSE the NumPy script prints. `.data[0]` is pre-0.4 PyTorch; newer
    # versions need .item() instead.
    print(epoch, error.mean().data[0])

Numpy Implementation:

import numpy as np

# Hyper-parameters shared by the data encoding and the training loop.
input_size = 10  # bits per encoded number; 2 ** 10 = 1024 covers 0..999
epochs = 1000  # passes over the training split
batches = 64  # rows per mini-batch (a batch *size*, despite the name)
lr = 1e-2  # SGD step size


def sig(val):
    """Elementwise logistic sigmoid: 1 / (1 + exp(-val))."""
    denom = 1 + np.exp(-val)
    return 1 / denom


def sig_d(val):
    """Derivative of ``sig`` at ``val``: sig(val) * (1 - sig(val))."""
    activated = sig(val)
    complement = 1 - activated
    return activated * complement


def binary_enc(num, width=None):
    """Encode a non-negative integer as a fixed-width big-endian bit list.

    Args:
        num: integer to encode; must satisfy 0 <= num < 2 ** width.
        width: number of bits in the result. Defaults to the module-level
            ``input_size`` so existing one-argument calls are unchanged.

    Returns:
        A list of ``width`` ints (0 or 1), most significant bit first.

    Raises:
        ValueError: if ``num`` is negative or does not fit in ``width`` bits,
            instead of silently returning a list of the wrong length.
    """
    if width is None:
        width = input_size
    if not 0 <= num < 2 ** width:
        raise ValueError(
            'cannot encode {} in {} bits'.format(num, width))
    # '0{width}b' zero-pads the binary representation to exactly width digits.
    return [int(bit) for bit in '{0:0{1}b}'.format(num, width)]


def binary_dec(array):
    """Decode a big-endian sequence of bits back into an integer.

    Each element is passed through int(), so float or NumPy bit values are
    accepted. An empty sequence decodes to 0.
    """
    digits = [int(bit) for bit in array]
    return sum(digit * 2 ** power
               for power, digit in enumerate(reversed(digits)))


def training_test_gen(x, y, train_frac=0.9):
    """Shuffle paired samples and split them into training and test sets.

    Args:
        x: array of input rows; indexed with an integer index array.
        y: array of target rows, paired with ``x`` by position.
        train_frac: fraction of rows (rounded down) that go to the training
            split; the remainder forms the test split.

    Returns:
        Tuple ``(trX, trY, teX, teY)``.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """
    if len(x) != len(y):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError(
            'x and y must have the same length ({} != {})'.format(
                len(x), len(y)))
    # permutation(n) shuffles arange(n) exactly like permutation(range(n)),
    # so a given seed still yields the same split.
    indices = np.random.permutation(len(x))
    split_size = int(train_frac * len(indices))
    train_idx, test_idx = indices[:split_size], indices[split_size:]
    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]


def x_y_gen():
    """Build the FizzBuzz dataset for 0..999 and split it into train/test.

    Each number becomes a binary input row (via binary_enc) and a one-hot
    target ordered [fizbuz, buz, fiz, number].

    Returns:
        The (trX, trY, teX, teY) tuple produced by training_test_gen.
    """
    def one_hot(n):
        # Multiples of 15 are tested first so they are not labelled buz/fiz.
        if n % 15 == 0:
            return [1, 0, 0, 0]
        if n % 5 == 0:
            return [0, 1, 0, 0]
        if n % 3 == 0:
            return [0, 0, 1, 0]
        return [0, 0, 0, 1]

    numbers = range(1000)
    inputs = np.array([binary_enc(n) for n in numbers])
    targets = np.array([one_hot(n) for n in numbers])
    return training_test_gen(inputs, targets)


def check_fizbuz(i):
    """Return the FizzBuzz category name for integer ``i``.

    Divisibility by 15 is checked first so multiples of both 3 and 5 map to
    'fizbuz' rather than 'buz' or 'fiz'.
    """
    for divisor, word in ((15, 'fizbuz'), (5, 'buz'), (3, 'fiz')):
        if i % divisor == 0:
            return word
    return 'number'


# NOTE(review): np.random is not seeded in this script, so the split and the
# initial weights differ on every run.
trX, trY, teX, teY = x_y_gen()

w1 = np.random.randn(10, 100)
w2 = np.random.randn(100, 4)

b1 = np.zeros((1, 100))
b2 = np.zeros((1, 4))

# Number of full mini-batches; a trailing partial batch is skipped, which is
# what lets np.ones([1, batches]) below match every batch's row count.
no_of_batches = int(len(trX) / batches)
for epoch in range(epochs):
    for batch in range(no_of_batches):
        # forward
        start = batch * batches
        end = start + batches
        x = trX[start:end]
        y = trY[start:end]
        a2 = x.dot(w1) + b1
        h2 = sig(a2)
        a3 = h2.dot(w2) + b2
        hyp = sig(a3)
        error = hyp - y
        # Reported metric only: the hand-written gradients below use
        # d(loss)/d(error) = error, i.e. the gradient of sum(error**2)/2,
        # not of this mean.
        loss = (error ** 2).mean()

        # backward
        outerror = error
        # Gradient w.r.t. the output pre-activation a3.
        outgrad = outerror * sig_d(a3)
        outdelta = h2.T.dot(outgrad)
        # A row of ones times outgrad sums the gradient over the batch rows.
        outbiasdelta = np.ones([1, batches]).dot(outgrad)

        # BUG: should be outgrad.dot(w2.T). The hidden layer must receive the
        # gradient w.r.t. a3 (which includes sig_d(a3)), not the raw output
        # error, so the hidden-layer updates below are not true gradients.
        hiddenerror = outerror.dot(w2.T)
        hiddengrad = hiddenerror * sig_d(a2)
        hiddendelta = x.T.dot(hiddengrad)
        hiddenbiasdelta = np.ones([1, batches]).dot(hiddengrad)

        w1 -= hiddendelta * lr
        b1 -= hiddenbiasdelta * lr
        w2 -= outdelta * lr
        b2 -= outbiasdelta * lr
    # Loss of the last mini-batch of the epoch, not an epoch average.
    print(epoch, loss)

# test
a2 = teX.dot(w1) + b1
h2 = sig(a2)
a3 = h2.dot(w2) + b2
hyp = sig(a3)
outli = ['fizbuz', 'buz', 'fiz', 'number']
for i in range(len(teX)):
    num = binary_dec(teX[i])
    print(
        'Number: {} -- Actual: {} -- Prediction: {}'.format(
            num, check_fizbuz(num), outli[hyp[i].argmax()]))
# NOTE(review): this averages *signed* errors, which largely cancel out; a
# squared-error metric would be np.mean((teY - hyp) ** 2).
print('Test loss: ', np.mean(teY - hyp))

in torch:

loss = error.pow(2).sum()

in numpy:

loss = (error ** 2).mean()

That’s the point. The learning rate is too small; maybe change it to lr*(end-start), or use:

loss =error.pow(2).mean()

That fixes it on my computer.

Note that PyTorch also implements the ** operator, so you can effectively use the same code for both torch and NumPy.

@chenyuntc I don’t see any improvement from changing sum to mean :pensive:
As for changing the learning rate: tuning that hyperparameter would definitely improve the performance, because my NN was still learning. But my problem is not my network’s performance in isolation. NumPy gives me a loss of 0.009 and an accuracy above 0.98 with the same hyperparameters I used for PyTorch, but PyTorch is nowhere near.

I thought pow(2) and ** 2 were the same and would not affect the output. Will try that anyway.

Sorry about the mistake.
I found it. You missed:

        b1.grad.data.zero_()
        b2.grad.data.zero_()

also print

 print(epoch, (error**2).mean().data[0])

you should get similar results as numpy.

1 Like

@chenyuntc
I see some improvements, but it is still far from NumPy’s results. I set the random seed to 10000 and printed the weights in each epoch. I could not understand two specific behaviors:

  1. NumPy weights are shown with 8 decimal places (0.65897509), while PyTorch rounds them to 4 digits (0.6590).
  2. The weight values change much faster in NumPy.

Change of the first weight matrix’s first value in NumPy:

0.65897509416461408
0.63824423351700321
0.70742434009324673
0.74590493637361743
0.76070011091500567
0.76704695584086358
0.77166554222431838
0.77654595399345228
0.78221473000011987
0.7887668125254953

Change of the first weight matrix’s first value in PyTorch:

0.6590
0.6585
0.6580
0.6575
0.6570
0.6565
0.6561
0.6556
0.6551
0.6547
0.6542

Since both scripts use the same seed (10000), we got the same initial value (PyTorch rounded it, though), but the change is super slow in PyTorch.

  1. It’s not rounded in torch; it’s just printed in this format.
  2. In PyTorch, you should use loss = error.pow(2).sum() / 2. to get the same gradient as NumPy — sorry about the mistake.
  3. If it still doesn’t work, you can print w1.grad.data and compare it to the gradient you compute. That may be helpful.
1 Like
  1. So when I print it, PyTorch rounds the value for display, but for computation it uses the actual value?
    I’ll do 2 & 3 and verify. Thanks @chenyuntc

PyTorch computes using the full precision of the data type; only the display truncates the numbers.
You can change that with torch.set_printoptions (for example, its precision argument).

3 Likes

I could see some differences again (not exactly an improvement), but it is still far from NumPy’s accuracy. I’m trying different approaches and asking around, and will update here if I find a solution. Please let me know if you have any other intuitions.

The difference actually comes from your NumPy code:

hiddenerror = outerror.dot(w2.T)

which should be:

hiddenerror = outgrad.dot(w2.T)

Even without this change, both the PyTorch and the NumPy code should converge to similar results (0.038/0.014), so maybe something else is wrong in the code you are running.

I used the code below to test and got nearly the same results:

  • pytorch: (999, 0.03878042474389076)
  • numpy: ( 999, 0.038780463080550241)

pytorch code:

import numpy as np
import torch as th
from torch.autograd import Variable


# Hyper-parameters shared by the data encoding and the training loop.
input_size = 10  # bits per encoded number; 2 ** 10 = 1024 covers 0..999
epochs = 1000  # passes over the training split
batches = 64  # rows per mini-batch (a batch *size*, despite the name)
lr = 1e-2  # SGD step size
# Fixed seed: the data split and the initial weights both come from np.random.
np.random.seed(10000)

def binary_enc(num, width=None):
    """Encode a non-negative integer as a fixed-width big-endian bit list.

    Args:
        num: integer to encode; must satisfy 0 <= num < 2 ** width.
        width: number of bits in the result. Defaults to the module-level
            ``input_size`` so existing one-argument calls are unchanged.

    Returns:
        A list of ``width`` ints (0 or 1), most significant bit first.

    Raises:
        ValueError: if ``num`` is negative or does not fit in ``width`` bits,
            instead of silently returning a list of the wrong length.
    """
    if width is None:
        width = input_size
    if not 0 <= num < 2 ** width:
        raise ValueError(
            'cannot encode {} in {} bits'.format(num, width))
    # '0{width}b' zero-pads the binary representation to exactly width digits.
    return [int(bit) for bit in '{0:0{1}b}'.format(num, width)]


def binary_dec(array):
    """Decode a big-endian sequence of bits back into an integer.

    Each element is passed through int(), so float or NumPy bit values are
    accepted. An empty sequence decodes to 0.
    """
    digits = [int(bit) for bit in array]
    return sum(digit * 2 ** power
               for power, digit in enumerate(reversed(digits)))


def training_test_gen(x, y, train_frac=0.9):
    """Shuffle paired samples and split them into training and test sets.

    Args:
        x: array of input rows; indexed with an integer index array.
        y: array of target rows, paired with ``x`` by position.
        train_frac: fraction of rows (rounded down) that go to the training
            split; the remainder forms the test split.

    Returns:
        Tuple ``(trX, trY, teX, teY)``.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """
    if len(x) != len(y):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError(
            'x and y must have the same length ({} != {})'.format(
                len(x), len(y)))
    # permutation(n) shuffles arange(n) exactly like permutation(range(n)),
    # so a given seed still yields the same split.
    indices = np.random.permutation(len(x))
    split_size = int(train_frac * len(indices))
    train_idx, test_idx = indices[:split_size], indices[split_size:]
    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]


def x_y_gen():
    """Build the FizzBuzz dataset for 0..999 and split it into train/test.

    Each number becomes a binary input row (via binary_enc) and a one-hot
    target ordered [fizbuz, buz, fiz, number].

    Returns:
        The (trX, trY, teX, teY) tuple produced by training_test_gen.
    """
    def one_hot(n):
        # Multiples of 15 are tested first so they are not labelled buz/fiz.
        if n % 15 == 0:
            return [1, 0, 0, 0]
        if n % 5 == 0:
            return [0, 1, 0, 0]
        if n % 3 == 0:
            return [0, 0, 1, 0]
        return [0, 0, 0, 1]

    numbers = range(1000)
    inputs = np.array([binary_enc(n) for n in numbers])
    targets = np.array([one_hot(n) for n in numbers])
    return training_test_gen(inputs, targets)


def check_fizbuz(i):
    """Return the FizzBuzz category name for integer ``i``.

    Divisibility by 15 is checked first so multiples of both 3 and 5 map to
    'fizbuz' rather than 'buz' or 'fiz'.
    """
    for divisor, word in ((15, 'fizbuz'), (5, 'buz'), (3, 'fiz')):
        if i % divisor == 0:
            return word
    return 'number'


# Split the data and wrap it in (pre-0.4 style) autograd Variables.
trX, trY, teX, teY = x_y_gen()
if th.cuda.is_available():
    dtype = th.cuda.FloatTensor
else:
    dtype = th.FloatTensor
x = Variable(th.from_numpy(trX).type(dtype), requires_grad=False)
y = Variable(th.from_numpy(trY).type(dtype), requires_grad=False)


# Initial weights are drawn from the seeded np.random stream (not th.randn),
# so the values are reproducible across runs. .type(dtype) casts them to
# 32-bit floats, while NumPy arrays stay float64; that is why the printed
# losses of the two scripts agree only to about six digits.
w1 = np.random.randn(10, 100)
w2 = np.random.randn(100, 4)
w1 = Variable(th.from_numpy(w1).type(dtype), requires_grad=True)
w2 = Variable(th.from_numpy(w2).type(dtype), requires_grad=True)

b1 = Variable(th.zeros(1, 100).type(dtype), requires_grad=True)
b2 = Variable(th.zeros(1, 4).type(dtype), requires_grad=True)

# Number of full mini-batches; a trailing partial batch is skipped.
no_of_batches = int(len(trX) / batches)
for epoch in range(epochs):
    for batch in range(no_of_batches):
        start = batch * batches
        end = start + batches
        x_ = x[start:end]
        y_ = y[start:end]

        # Hidden layer: sigmoid(x_ @ w1 + b1).
        a2 = x_.mm(w1)
        a2 = a2.add(b1.expand_as(a2))
        h2 = a2.sigmoid()

        # Output layer: sigmoid(h2 @ w2 + b2).
        a3 = h2.mm(w2)
        a3 = a3.add(b2.expand_as(a3))
        hyp = a3.sigmoid()

        error = hyp - y_
        # sum(error**2)/2 has gradient `error` w.r.t. error, which is exactly
        # the output error the hand-written NumPy backprop starts from.
        loss = error.pow(2).sum()/2.0
        loss.backward()

        # Manual SGD step on the raw tensors.
        w1.data -= lr * w1.grad.data
        w2.data -= lr * w2.grad.data
        b1.data -= lr * b1.grad.data
        b2.data -= lr * b2.grad.data
        # Autograd accumulates into .grad, so every parameter's gradient is
        # reset after each step.
        w1.grad.data.zero_()
        w2.grad.data.zero_()
        b1.grad.data.zero_()
        b2.grad.data.zero_()
    # MSE of the last mini-batch (the metric the NumPy script prints).
    # NOTE(review): `.data[0]` is pre-0.4 PyTorch; newer versions need
    # .item() instead.
    print(epoch, (error**2).mean().data[0])

numpy code:

import numpy as np

# Hyper-parameters shared by the data encoding and the training loop.
input_size = 10  # bits per encoded number; 2 ** 10 = 1024 covers 0..999
epochs = 1000  # passes over the training split
batches = 64  # rows per mini-batch (a batch *size*, despite the name)
lr = 1e-2  # SGD step size

# Fixed seed: the data split and the initial weights both come from np.random.
np.random.seed(10000)
def sig(val):
    """Elementwise logistic sigmoid: 1 / (1 + exp(-val))."""
    denom = 1 + np.exp(-val)
    return 1 / denom


def sig_d(val):
    """Derivative of ``sig`` at ``val``: sig(val) * (1 - sig(val))."""
    activated = sig(val)
    complement = 1 - activated
    return activated * complement


def binary_enc(num, width=None):
    """Encode a non-negative integer as a fixed-width big-endian bit list.

    Args:
        num: integer to encode; must satisfy 0 <= num < 2 ** width.
        width: number of bits in the result. Defaults to the module-level
            ``input_size`` so existing one-argument calls are unchanged.

    Returns:
        A list of ``width`` ints (0 or 1), most significant bit first.

    Raises:
        ValueError: if ``num`` is negative or does not fit in ``width`` bits,
            instead of silently returning a list of the wrong length.
    """
    if width is None:
        width = input_size
    if not 0 <= num < 2 ** width:
        raise ValueError(
            'cannot encode {} in {} bits'.format(num, width))
    # '0{width}b' zero-pads the binary representation to exactly width digits.
    return [int(bit) for bit in '{0:0{1}b}'.format(num, width)]


def binary_dec(array):
    """Decode a big-endian sequence of bits back into an integer.

    Each element is passed through int(), so float or NumPy bit values are
    accepted. An empty sequence decodes to 0.
    """
    digits = [int(bit) for bit in array]
    return sum(digit * 2 ** power
               for power, digit in enumerate(reversed(digits)))


def training_test_gen(x, y, train_frac=0.9):
    """Shuffle paired samples and split them into training and test sets.

    Args:
        x: array of input rows; indexed with an integer index array.
        y: array of target rows, paired with ``x`` by position.
        train_frac: fraction of rows (rounded down) that go to the training
            split; the remainder forms the test split.

    Returns:
        Tuple ``(trX, trY, teX, teY)``.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """
    if len(x) != len(y):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise ValueError(
            'x and y must have the same length ({} != {})'.format(
                len(x), len(y)))
    # permutation(n) shuffles arange(n) exactly like permutation(range(n)),
    # so a given seed still yields the same split.
    indices = np.random.permutation(len(x))
    split_size = int(train_frac * len(indices))
    train_idx, test_idx = indices[:split_size], indices[split_size:]
    return x[train_idx], y[train_idx], x[test_idx], y[test_idx]


def x_y_gen():
    """Build the FizzBuzz dataset for 0..999 and split it into train/test.

    Each number becomes a binary input row (via binary_enc) and a one-hot
    target ordered [fizbuz, buz, fiz, number].

    Returns:
        The (trX, trY, teX, teY) tuple produced by training_test_gen.
    """
    def one_hot(n):
        # Multiples of 15 are tested first so they are not labelled buz/fiz.
        if n % 15 == 0:
            return [1, 0, 0, 0]
        if n % 5 == 0:
            return [0, 1, 0, 0]
        if n % 3 == 0:
            return [0, 0, 1, 0]
        return [0, 0, 0, 1]

    numbers = range(1000)
    inputs = np.array([binary_enc(n) for n in numbers])
    targets = np.array([one_hot(n) for n in numbers])
    return training_test_gen(inputs, targets)


def check_fizbuz(i):
    """Return the FizzBuzz category name for integer ``i``.

    Divisibility by 15 is checked first so multiples of both 3 and 5 map to
    'fizbuz' rather than 'buz' or 'fiz'.
    """
    for divisor, word in ((15, 'fizbuz'), (5, 'buz'), (3, 'fiz')):
        if i % divisor == 0:
            return word
    return 'number'


trX, trY, teX, teY = x_y_gen()

# Initial weights are drawn from the seeded np.random stream after the split,
# so they are reproducible from run to run.
w1 = np.random.randn(10, 100)
w2 = np.random.randn(100, 4)

b1 = np.zeros((1, 100))
b2 = np.zeros((1, 4))

# Number of full mini-batches; a trailing partial batch is skipped, which is
# what lets np.ones([1, batches]) below match every batch's row count.
no_of_batches = int(len(trX) / batches)
for epoch in range(epochs):
    for batch in range(no_of_batches):
        # forward
        start = batch * batches
        end = start + batches
        x = trX[start:end]
        y = trY[start:end]
        a2 = x.dot(w1) + b1
        h2 = sig(a2)
        a3 = h2.dot(w2) + b2
        hyp = sig(a3)
        error = hyp - y
        # Reported metric only: the gradients below start from
        # d(loss)/d(error) = error, i.e. they are those of sum(error**2)/2.
        loss = (error ** 2).mean()

        # backward
        outerror = error
        # Gradient w.r.t. the output pre-activation a3.
        outgrad = outerror * sig_d(a3)
        outdelta = h2.T.dot(outgrad)
        # A row of ones times outgrad sums the gradient over the batch rows.
        outbiasdelta = np.ones([1, batches]).dot(outgrad)

        # Propagate the pre-activation gradient (outgrad), not the raw output
        # error, back through w2 into the hidden layer.
        hiddenerror = outgrad.dot(w2.T)
        hiddengrad = hiddenerror * sig_d(a2)
        hiddendelta = x.T.dot(hiddengrad)
        hiddenbiasdelta = np.ones([1, batches]).dot(hiddengrad)

        w1 -= hiddendelta * lr
        b1 -= hiddenbiasdelta * lr
        w2 -= outdelta * lr
        b2 -= outbiasdelta * lr
    # Loss of the last mini-batch of the epoch, not an epoch average.
    print(epoch, loss)

# test
a2 = teX.dot(w1) + b1
h2 = sig(a2)
a3 = h2.dot(w2) + b2
hyp = sig(a3)
outli = ['fizbuz', 'buz', 'fiz', 'number']
for bits, scores in zip(teX, hyp):
    num = binary_dec(bits)
    print(
        'Number: {} -- Actual: {} -- Prediction: {}'.format(
            num, check_fizbuz(num), outli[scores.argmax()]))
# Mean squared error, the same metric as the training loss. Errors are squared
# before averaging because signed errors would largely cancel out.
print('Test loss: ', np.mean((teY - hyp) ** 2))
1 Like

This is awesome!! I got the same accuracy and loss. Thank you so much @chenyuntc. Could you please explain the intuition behind using

loss = error.pow(2).sum()/2.0

over

loss = error.pow(2).mean()

especially when you use

(error**2).mean().data[0]

while printing the error.

And could you describe how it is similar to Numpy’s implementation

loss = (error ** 2).mean()

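  • If loss = error.pow(2).sum()/2.0, then dloss/derror = error.
    If loss = error.pow(2).mean(), then dloss/derror = 2*error/N, where N is the number of elements in error (mean() averages over all of them): batch_size × 4 outputs = 64 × 4 = 256 here.
    The NumPy implementation back-propagates outerror = error directly, so the first form of the loss gives the same gradients.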

  • I print (error**2).mean().data[0] because that is what you compute and print in NumPy:

loss = (error ** 2).mean()
...
...
print(epoch, loss)
  • They are the same, except that the PyTorch code can call backward() and compute the gradients automatically.
2 Likes

@chenyuntc Amazing, you are my hero !!