Huge difference in train loss for the same network and data between PyTorch and Keras

I am learning PyTorch and am trying to replicate some small tensorflow-keras networks. Here it is:

## imports 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

## some arguments 
epochs = 10
device = 'cpu'

## making the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(in_features=8, out_features=32)
        self.relu1 = nn.ReLU(inplace=False)
        self.dropout1 = nn.Dropout(p=0.40)
        self.fc2 = nn.Linear(in_features=32, out_features=1)
        self.relu2 = nn.ReLU(inplace=False)
 
    def forward(self, input_tensor):
        x = self.relu1(self.fc1(input_tensor))
        x = self.dropout1(x)
        x = self.relu2(self.fc2(x))
        return x

## instantiating the model
mlf_net = Net().to(device)
mlf_net

## optimizer with same arguments as tf-keras
optimizer = optim.RMSprop(mlf_net.parameters(), lr=0.001, eps=1e-07)
optimizer

## loss
mse_loss = nn.MSELoss()

Here’s the data:

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

cal_X, cal_y = fetch_california_housing(return_X_y=True, as_frame=True)

## train-val splits
cal_X_train, cal_X_val, cal_y_train, cal_y_val = train_test_split(cal_X, cal_y,
                                                                  test_size=0.05,
                                                                  random_state=100)

## making datasets for dataloaders
train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(cal_X_train.astype(np.float32).to_numpy()),
                                               torch.from_numpy(cal_y_train.astype(np.float32).to_numpy()))

val_dataset = torch.utils.data.TensorDataset(torch.from_numpy(cal_X_val.astype(np.float32).to_numpy()),
                                             torch.from_numpy(cal_y_val.astype(np.float32).to_numpy()))

## making data-loaders for training
data_kwargs = {'batch_size': 256}

train_loader = torch.utils.data.DataLoader(train_dataset, **data_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **data_kwargs)

Training the network now:

for epoch in range(1, epochs+1):
    epoch_loss = 0.0
    running_loss = 0.0
    for i, (data, target) in enumerate(train_loader, 0):
        data, target = data.to(device), target.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = mlf_net(data)
        loss = mse_loss(outputs, target.reshape(-1,1))
        loss.backward()
        optimizer.step()

        epoch_loss += outputs.shape[0] * loss.item()

        # print statistics
        # print at the final batch
        # (i only goes up to len(train_loader) - 1, so this branch never runs as written)
        running_loss += loss.item()
        if i == len(train_loader):    
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
    
    # print epoch loss
    print(epoch, epoch_loss / len(train_dataset))

and the loss that gets printed is:

1 126.73289673302136
2 12.858341827338105
3 9.97556849218592
4 6.85695126045971
5 6.78658044586275
6 5.878446193968992
7 5.731223358323943
8 10.573452385140166
9 5.647310844687528
10 5.611948392479724

Here’s the very same network in tensorflow-keras:

from tensorflow import keras as tf_keras

## model
tfk_model = tf_keras.Sequential([
                tf_keras.layers.InputLayer(input_shape=(8,)),
                tf_keras.layers.Dense(units=32, activation="relu"),
                tf_keras.layers.Dropout(rate=0.4),
                tf_keras.layers.Dense(units=1, activation="relu")
            ])

## adding optimiser and loss
tfk_model.compile(optimizer='rmsprop',
                  loss='mean_squared_error')

## training the model
tfk_model.fit(x=cal_X_train, y=cal_y_train,
              epochs=2,
              verbose=2)

Here’s the loss that gets printed:

Epoch 1/2
583/583 - 1s - loss: 5.6080 - 1s/epoch - 2ms/step
Epoch 2/2
583/583 - 1s - loss: 5.6075 - 993ms/epoch - 2ms/step

The loss I got with PyTorch after 10 epochs is higher than the loss I got in the first epoch of tf-keras. What am I doing wrong here that leads to this difference in train loss?

Versions of both frameworks: pytorch 1.11.0+cu113 and tensorflow 2.8.0.

Are the initial parameters the same? Different initial parameters will lead to different loss values in the first epoch.

## with initial parameters generated using similar methods:
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0.0)

## instantiating the model
mlf_net = Net().to(device)
mlf_net.apply(init_weights)

and here are the train losses:

1 7.000502115753124
2 5.634935916953552
3 5.9406873486276455
4 5.663457669126214
5 5.640254720090993
6 5.7117453242068095
7 5.611794123694343
8 5.66521956433281
9 5.613274300939839
10 5.611646902215865
11 5.611850664013505
12 5.611476422134393
13 5.611364200620056
14 5.611674008200189
15 5.610997074528744

The initialisation matches now, but the train loss after 15 epochs still doesn’t match the first-epoch train loss of keras.


Also, can you confirm whether everything else is the same between the PyTorch and Keras networks, apart from whatever is influenced by each framework’s default arguments?
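For what it's worth, a few defaults do differ between the two snippets above: Keras `fit()` uses `batch_size=32` and `shuffle=True` unless told otherwise, Keras `'rmsprop'` uses `rho=0.9` while PyTorch's `RMSprop` uses `alpha=0.99`, and Keras `Dense` layers default to Glorot uniform weights with zero biases. Below is a minimal sketch of lining the PyTorch side up with those defaults. The random tensors are only stand-ins for the housing data, and `keras_like_init` is a hypothetical helper name. Matching these settings is not guaranteed to make the losses identical.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in data with 8 feature columns; the values are random, not the real dataset.
X = torch.randn(1000, 8)
y = torch.randn(1000)
dataset = torch.utils.data.TensorDataset(X, y)

# Keras fit() defaults: batch_size=32, shuffle=True.
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# Same layer stack as the Net class in the question.
model = nn.Sequential(
    nn.Linear(8, 32), nn.ReLU(), nn.Dropout(p=0.4),
    nn.Linear(32, 1), nn.ReLU(),
)

def keras_like_init(m):
    # Hypothetical helper: Glorot (Xavier) uniform weights and zero biases,
    # which is what Keras Dense uses by default.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

model.apply(keras_like_init)

# Keras RMSprop defaults: learning_rate=0.001, rho=0.9, epsilon=1e-07.
# PyTorch calls the decay factor alpha and defaults it to 0.99.
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9, eps=1e-07)
```
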

I’m not a developer, so on the technicals of PyTorch I’ll defer to a dev to answer those specifics.

I do know that PyTorch’s implementation of RMSprop differs from TF’s: TF includes the numerical stability constant inside the sqrt, while PyTorch adds it outside the sqrt. This can lead to differences in behavior in the final stages of convergence.
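To make the epsilon placement concrete, here is a small scalar sketch in plain Python. It is not either framework's actual code; the function names and the choice of `alpha=0.9` are only for illustration. It applies a single RMSprop-style step starting from a zero running average, once with eps added after the square root and once with eps inside it.

```python
import math

def rmsprop_step_eps_outside(grad, square_avg, lr=0.001, alpha=0.9, eps=1e-7):
    """One scalar step with eps added after the square root."""
    square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
    step = lr * grad / (math.sqrt(square_avg) + eps)
    return step, square_avg

def rmsprop_step_eps_inside(grad, square_avg, lr=0.001, alpha=0.9, eps=1e-7):
    """One scalar step with eps added inside the square root."""
    square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
    step = lr * grad / math.sqrt(square_avg + eps)
    return step, square_avg

# Compare the two variants for a large, a medium, and a tiny gradient.
for grad in (1.0, 1e-3, 1e-6):
    outside, _ = rmsprop_step_eps_outside(grad, 0.0)
    inside, _ = rmsprop_step_eps_inside(grad, 0.0)
    print(f"grad={grad:g}  eps outside: {outside:.3e}  eps inside: {inside:.3e}")
```

For large gradients the two steps are practically the same. Once the squared-gradient average becomes comparable to eps, the variant with eps inside the square root takes a noticeably smaller step, which is where late-stage behavior can diverge.
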

Also, don’t use the data attribute of a Tensor; fill the Tensor directly instead.
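As a sketch of what that could look like for the `init_weights` function posted earlier, the bias can be zeroed through `torch.nn.init` (whose functions modify the tensor in place without recording gradients), or filled directly inside a `torch.no_grad()` block. The single `nn.Linear` layer below is only there to show the call.

```python
import torch
import torch.nn as nn

def init_weights(m):
    # Initialise Linear layers without going through the .data attribute.
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)
        # Equivalent alternative for the bias:
        # with torch.no_grad():
        #     m.bias.fill_(0.0)

layer = nn.Linear(in_features=8, out_features=32)
init_weights(layer)
```
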
