Reproducing the same results as Keras in PyTorch

I am trying to rewrite some very simple Keras code in PyTorch, but I can't reproduce the exact same results in PyTorch.
Keras code:

import keras
import tensorflow as tf
import torch
from keras.layers import Input
from keras.layers.core import Dense
from keras.models import Model
from tensorflow.keras import initializers
def ToyNet(nb_classes=3, img_dim=(2,)):
    """Build the Keras reference MLP: two 125-unit ReLU layers and a linear head.

    Args:
        nb_classes: Number of units in the final Dense layer, i.e. the number
            of logits produced per sample.
        img_dim: Shape of one input sample, passed to ``Input(shape=...)``.

    Returns:
        A functional ``Model`` named "DenseNet" that outputs raw logits
        (no softmax is applied inside the model).
    """
    inputs = Input(shape=img_dim)
    features = inputs
    # Two identical hidden layers, created in the same order as before so the
    # auto-generated layer names and initializer draws are unchanged.
    for _ in range(2):
        features = Dense(125, activation='relu')(features)
    logits = Dense(nb_classes)(features)
    return Model(inputs=[inputs], outputs=[logits], name="DenseNet")

def fn_minus(correct, predicted):
    """Per-sample Keras loss: softmax cross-entropy minus a sigmoid-confidence term.

    For every row r:
        loss_r = CE(correct_r, softmax(predicted_r))
                 - (max(correct_r) - 0.5) * mean(sigmoid(predicted_r))

    With one-hot rows the factor is +0.5, so high sigmoid outputs lower the
    loss. With the uniform [0.33, 0.33, 0.33] rows produced by ``get_data`` the
    factor is -0.17, so high sigmoid outputs raise it.

    Args:
        correct: Label matrix, one row per sample (assumes 2-D
            ``(batch, classes)`` — reductions are over axis 1).
        predicted: Raw logits with the same shape as ``correct``.

    Returns:
        One loss value per sample.
    """
    label_peak_offset = tf.reduce_max(correct, axis=1) - 0.5
    mean_confidence = tf.reduce_mean(tf.nn.sigmoid(predicted), axis=1)
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels=correct, logits=predicted
    )
    return cross_entropy - label_peak_offset * mean_confidence

# Build and compile the Keras reference model (was `Toy  Net()`: a syntax error).
model_minus = ToyNet()
model_minus.summary()
opt = keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
model_minus.compile(loss=fn_minus, optimizer=opt, metrics=["accuracy"])
print("Finished compiling")

####################
# Network training
####################
# NOTE(review): x_train / y_train must already exist here; in this listing they
# are produced by get_data() further down, so run that part first.
# get_data() returns 600 cluster points + 600 background points, so
# batch_size=1200 makes every epoch a single full-batch step ("1/1" in the log).
# shuffle=False keeps the row order fixed, like the PyTorch loop. With a full
# batch, shuffling only permutes rows, but that still changes float32
# summation order and therefore the last digits of the loss/gradients.
print("Fitting the model … ")
model_minus.fit(x_train, y_train, batch_size=1200, epochs=5, verbose=1, shuffle=False)
print("Done training …model_minus ")

PyTorch code:

import random

import torch
from torch import nn

# Seed PyTorch's and Python's global RNGs so that nn.Linear's default
# initialization (and anything drawn from `random`) repeats across runs.
torch.manual_seed(9988)
random.seed(9988)
class torch_ToyNet(nn.Module):
    """PyTorch mirror of the Keras ToyNet: two 125-unit ReLU layers + linear head.

    The submodules are named ``fc1``/``fc2``/``fc3`` (other code copies weights
    into them by name). They are created in that order, so the parameter order
    and the RNG draws for their default initialization are fixed for a given seed.
    """

    def __init__(self, in_features=2, nb_classes=3):
        super().__init__()
        width = 125
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, nb_classes)

    def forward(self, x):
        """Return raw logits (no softmax) for a batch ``x`` with ``in_features`` columns."""
        for hidden_layer in (self.fc1, self.fc2):
            x = nn.functional.relu(hidden_layer(x))
        return self.fc3(x)

def torch_fn_minus(correct, predicted):
    """PyTorch port of the Keras ``fn_minus`` loss, returned per sample.

    For every row r:
        loss_r = CE(correct_r, softmax(predicted_r))
                 - (max(correct_r) - 0.5) * mean(sigmoid(predicted_r))

    Args:
        correct: Label matrix, one row per sample (one-hot or soft labels such
            as [0.33, 0.33, 0.33]). Labels built with ``torch.from_numpy`` on
            the float64 numpy arrays of ``get_data`` arrive as float64.
        predicted: Raw logits, same shape as ``correct``.

    Returns:
        1-D tensor of per-sample losses in ``predicted``'s dtype. The caller
        reduces it (e.g. ``.mean()``).
    """
    # Compute the whole loss in the logits' precision and on their device.
    # Without this, float64 labels promote part of the expression to float64
    # (or make soft-target cross_entropy reject the dtype mix). NOTE(review):
    # Keras presumably feeds float32 labels to its loss (default floatx), so
    # this also matches the Keras side more closely.
    correct = correct.to(device=predicted.device, dtype=predicted.dtype)
    y_max = correct.max(dim=1).values - 0.5
    y_sgm = torch.sigmoid(predicted)
    # Probability (soft) targets: cross_entropy computes -sum(correct * log_softmax).
    ce = nn.functional.cross_entropy(predicted, correct, reduction='none')
    return ce - y_max * y_sgm.mean(dim=1)

model = torch_ToyNet()
# NOTE(review): if you copy the Keras initialization into `model`, do it after
# this line and before the optimizer is built, and don't re-create `model`
# afterwards. Otherwise the copied weights are discarded, or `optim` keeps
# stepping the parameters of a model that is no longer used.
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

# Convert the numpy data once instead of every iteration. y_train is a float64
# array; casting it to float32 keeps the loss in the same precision as the
# float32 logits.
inputs = torch.from_numpy(x_train).float()
targets = torch.from_numpy(y_train).float()

# Each "epoch" is one full-batch update, like Keras' fit(batch_size=1200) on the
# 1200-sample dataset. The printed loss is computed before the update, as Keras
# reports it.
for epoch in range(5):
    pred_y = model(inputs)
    loss = torch_fn_minus(targets, pred_y).mean()
    print(epoch, loss.item())
    optim.zero_grad()
    loss.backward()
    optim.step()

To generate the dataset:

import os

import matplotlib.pyplot as plt
import numpy as np

# NOTE: the interpreter's hash seed is fixed at start-up, so setting
# PYTHONHASHSEED here only affects child processes launched afterwards, not
# the current process.
os.environ['PYTHONHASHSEED'] = '0'
# Seed numpy's global RNG, which get_data() draws all of its samples from.
np.random.seed(9988)

def get_data(plot=True):
    """Generate the 2-D toy dataset: three Gaussian clusters plus background points.

    Draws 300 samples from each of three Gaussians with covariance 4*I centred
    at (-4, 0), (4, 0) and (0, 5), and keeps the first 200 of each with one-hot
    labels (class order). Then rejection-samples 600 points from the box
    [-15, 15] x [-13, 17] that are farther than 6.5 from every centre and
    labels them [0.33, 0.33, 0.33]. Uses numpy's global RNG (same call order as
    before), so seed it beforehand for reproducible data.

    Args:
        plot: If True (default), also draw the samples on the current
            matplotlib figure, as before. Pass False to generate the data
            without touching matplotlib.

    Returns:
        (dataX, dataY): float64 arrays of shape (1200, 2) and (1200, 3).
        Rows 0-599 are the clusters (200 per class), rows 600-1199 are the
        background points.
    """
    d = 300  # samples drawn per cluster
    h = 200  # samples kept per cluster
    cov = [[4, 0], [0, 4]]
    # (mean, one-hot column, plot marker). The order fixes the RNG draw order.
    clusters = [([-4, 0], 0, 'bx'), ([4, 0], 1, 'yx'), ([0, 5], 2, 'rx')]
    means = [mean for mean, _, _ in clusters]

    cluster_x, cluster_y = [], []
    for mean, label, marker in clusters:
        x = np.random.multivariate_normal(mean, cov, d)
        y = np.zeros([d, 3])
        y[:, label] = 1
        if plot:
            plt.plot(x[:, 0], x[:, 1], marker)
        cluster_x.append(x[:h, :])
        cluster_y.append(y[:h, :])
    dataX = np.concatenate(cluster_x, axis=0)
    dataY = np.concatenate(cluster_y, axis=0)

    dist_val = 6.5

    def euclid_dist(i, j, mu):
        # Same arithmetic as before, so acceptance decisions are bit-identical.
        return np.sqrt((i - mu[0]) * (i - mu[0]) + (j - mu[1]) * (j - mu[1]))

    # Rejection sampling. Accepted points are collected in a list and
    # concatenated once; growing the array on every accept copied it each time
    # (quadratic).
    background = []
    while len(background) < 600:
        i = np.random.uniform(-15, 15)
        j = np.random.uniform(-13, 17)
        if all(euclid_dist(i, j, mu) > dist_val for mu in means):
            background.append([i, j])
    dataX = np.concatenate((dataX, np.array(background)), axis=0)
    dataY = np.concatenate((dataY, np.full((len(background), 3), 0.33)), axis=0)

    if plot:
        plt.plot(dataX[h * 3:, 0], dataX[h * 3:, 1], 'k,')
        # Single line-style points at the corners of [-20, 20]^2. Nothing is
        # drawn, but they widen the autoscaled axis limits.
        plt.plot(-20, 20)
        plt.plot(20, -20)
        plt.plot(-20, -20)
        plt.plot(20, 20)
        plt.xlabel("x")
        plt.ylabel("y")

    return dataX, dataY

# ---- Generate the training set ----
# get_data() also plots the cluster (in-domain) points and the background
# points on the current matplotlib figure.

(x_train, y_train) = get_data()

To give both models the same initialization, I also tried this snippet:

model_minus = ToyNet()
model = torch_ToyNet()
weights = model_minus.get_weights()

# Copy the Keras initialization into the PyTorch model.
# get_weights() yields [kernel, bias] for each Dense layer, in layer order. A
# Keras kernel is (in_features, out_features), while nn.Linear.weight is
# (out_features, in_features), hence the transpose.
# Copy in place under no_grad instead of rebinding `.data`:
#   - each Parameter keeps its own contiguous storage and its dtype (values are
#     cast on copy);
#   - the same Parameter objects remain registered with any optimizer;
#   - a shape mismatch raises instead of silently replacing the tensor.
torch_layers = (model.fc1, model.fc2, model.fc3)
if len(weights) != 2 * len(torch_layers):
    raise ValueError(
        f"expected {2 * len(torch_layers)} Keras weight arrays, got {len(weights)}"
    )
with torch.no_grad():
    for layer, kernel, bias in zip(torch_layers, weights[0::2], weights[1::2]):
        for param, src in ((layer.weight, kernel.T), (layer.bias, bias)):
            src_t = torch.from_numpy(np.ascontiguousarray(src))
            if src_t.shape != param.shape:
                raise ValueError(
                    f"shape mismatch: Keras {tuple(src_t.shape)} vs PyTorch {tuple(param.shape)}"
                )
            param.copy_(src_t)
# NOTE(review): this rebinds both `model_minus` and `model`. Compile this new
# Keras model, and build the PyTorch optimizer from this `model`, before training.

But the results still don't match, even after the 4th epoch. Please help me figure out what I am doing wrong.
Keras result:
Epoch 1/5
1/1 [==============================] - 0s 225ms/step - loss: 1.2290 - accuracy: 0.3342
Epoch 2/5
1/1 [==============================] - 0s 3ms/step - loss: 1.1335 - accuracy: 0.3242
Epoch 3/5
1/1 [==============================] - 0s 4ms/step - loss: 1.0568 - accuracy: 0.3450
Epoch 4/5
1/1 [==============================] - 0s 6ms/step - loss: 1.0024 - accuracy: 0.4467
Epoch 5/5
1/1 [==============================] - 0s 4ms/step - loss: 0.9697 - accuracy: 0.5575

PyTorch result (loss):
0 1.2290266651050497
1 1.1339280305054038
2 1.057344033067301
3 1.0029294972211125
4 0.9699580834043524

I would recommend scaling down the use case first: try to match a single layer, then the whole model, then the overall training routine. This should make it easier to see where the actual difference is coming from.

Thanks for the response. I will do it.

1m

Hi ptrblck,
I think the difference arises during the parameter update (Keras and PyTorch compute different gradients). Here is the difference in the parameter values after a single epoch:
[ 1.30459666e-05 7.01919198e-05 -4.84772027e-05 1.13621354e-05
1.39921904e-05 4.14252281e-06 -1.55717134e-06 -2.32458115e-05
-4.48524952e-06 -7.43567944e-06 -1.77174807e-05 -2.56709754e-05
9.83476639e-07 -1.83135271e-05 1.97440386e-07 1.08666718e-05
-4.64320183e-05 -1.48005784e-05 -1.26659870e-07 1.71661377e-05
5.79208136e-05 1.06766820e-05 -5.73694706e-06 -2.63266265e-05
-2.78651714e-06 -4.74005938e-05 -9.67085361e-06 1.08256936e-05
-5.96046448e-08 -7.15535134e-06 -1.15442090e-05 3.53902578e-05
3.63588333e-06 -3.63811851e-05 -5.70528209e-06 -2.40355730e-05
-2.61813402e-05 4.81307507e-06 2.07275152e-05 -1.30124390e-05
-3.78936529e-05 3.81423160e-06 -2.05934048e-05 3.41087580e-05
-1.45405065e-05 1.57207251e-06 -2.62260437e-06 8.92765820e-06
-4.14252281e-06 4.36902046e-05 -8.24779272e-06 -1.73319131e-06
3.56743112e-06 1.15931034e-05 6.37769699e-06 -1.43423676e-05
4.80562449e-05 3.19778919e-05 1.08033419e-05 2.02804804e-05
-4.77731228e-05 1.77472830e-05 2.56672502e-05 -1.47521496e-06
-4.59700823e-06 -4.17679548e-05 3.40417027e-05 6.02751970e-05
-5.75184822e-06 1.37612224e-05 -1.82026997e-06 7.67409801e-06
1.10268593e-05 7.48038292e-06 -4.33921814e-05 -1.79931521e-05
1.81943178e-05 2.53617764e-05 -5.81145287e-06 -1.72555447e-05
1.91479921e-05 -2.04853714e-05 5.29587269e-05 -1.89188868e-05
5.93811274e-06 3.74242663e-05 -1.87754631e-05 -3.93390656e-06
7.81938434e-06 7.01472163e-06 -2.66544521e-06 2.51282472e-07
-4.64171171e-06 2.47657299e-05 2.35140324e-05 4.26769257e-05
-3.59714031e-05 1.32769346e-05 -7.30156898e-07 5.40837646e-05
-1.23083591e-05 3.46414745e-05 1.25318766e-05 -9.51439142e-06
4.94103879e-05 -1.96248293e-05 -4.27290797e-05 2.08765268e-05
-9.90927219e-06 -1.98334455e-05 -1.75088644e-05 -4.52324748e-05
-1.19209290e-06 7.44033605e-06 -3.07410955e-05 -4.74974513e-07
5.65126538e-06 2.28732824e-05 -7.15255737e-06 1.92523003e-05
-1.94758177e-05 8.58306885e-06 -9.83476639e-06 -1.37090683e-05
2.15172768e-05]

An absolute error of ~1e-5 is to be expected for float32 due to its limited floating-point precision.
I doubt this small error explains your huge mismatches.