Recreating Keras performance in PyTorch

I am doing regression on 3D arrays using convolutional networks. I have a model that performs well in Keras, producing an L1 loss of 0.006 on the validation dataset and satisfactory predicted-vs.-actual plots on test data. I have been trying to reproduce this model as a baseline in PyTorch so I can explore other architectures and model types whose libraries are written in PyTorch. However, I cannot replicate the Keras model’s performance in PyTorch. I have set everything up as identically as I can, including:

  1. kernel/weight initialization – set to “glorot_uniform” in Keras, which corresponds to “xavier_uniform_” in PyTorch
  2. bias initialization – set to “zeros” in Keras, set to zeros in PyTorch
  3. data is exactly the same: I use a DataLoader in PyTorch with batch_size = 16 for both the training and validation loaders; in Keras that is set in the .fit() call as follows:
    cnn.fit(Xtr, Ytr, batch_size=16, epochs=epochs, verbose=2, validation_data=(Xval, Yval), validation_batch_size=16, callbacks=[checkpointcb, stopcallback])
    a. I thought maybe my data was not indexed correctly (e.g., X1 paired with Y3), or that the DataLoader was loading incorrectly. I caught an issue with my data and confirmed that the DataLoader loads indices correctly (why wouldn’t it?). Still no improvement in model performance.
  4. use the Adam optimizer with learning rate = 0.001 (the default in both) and eps=1e-07 in PyTorch to match Keras’s epsilon (the PyTorch default is 1e-08); all other defaults are the same (see the sketch below).
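A minimal sketch of point 4, assuming the Keras model is called cnn (as in the .fit() call above) and the PyTorch model instance is called model (defined further down):

import torch
from tensorflow.keras.optimizers import Adam

# Keras: learning_rate=0.001 and epsilon=1e-7 are already the defaults
keras_opt = Adam(learning_rate=0.001, epsilon=1e-7)

# PyTorch: eps defaults to 1e-8, so it is set explicitly to match Keras
torch_opt = torch.optim.Adam(model.parameters(), lr=0.001, eps=1e-7)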

Here are the models.

  1. Keras:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Input

input = Input(shape=(1, 16, 16, 16), batch_size=None)

x = layers.Conv3D(64, kernel_size=5, activation='relu', padding='same', data_format='channels_first', kernel_initializer='glorot_uniform')(input)
x = layers.BatchNormalization(axis=1)(x)
x = layers.Dropout(0.25)(x)

res = layers.Conv3D(64, kernel_size=1, padding='valid', data_format='channels_first', kernel_initializer='glorot_uniform')(input)
x = layers.Add()([x, res])

x = layers.Conv3D(64, kernel_size=5, activation='relu', padding='same', data_format='channels_first', kernel_initializer='glorot_uniform')(x)
x = layers.BatchNormalization(axis=1)(x)
x = layers.Dropout(0.25)(x)

res = layers.Conv3D(64, kernel_size=1, padding='valid', data_format='channels_first', kernel_initializer='glorot_uniform')(input)
x = layers.Add()([x, res])

x = layers.Conv3D(128, kernel_size=5, activation='relu', padding='same', data_format='channels_first', kernel_initializer='glorot_uniform')(x)
x = layers.BatchNormalization(axis=1)(x)
x = layers.Dropout(0.25)(x)

x = layers.Flatten()(x)

x = layers.Dense(64, kernel_initializer='glorot_uniform')(x)
x = layers.Dropout(0.25)(x)
x = layers.Dense(32, kernel_initializer='glorot_uniform')(x)
output = layers.Dense(1, kernel_initializer='glorot_uniform')(x)
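For completeness, a minimal sketch of how this graph would be wrapped and compiled into the cnn used in the .fit() call above; the 'mae' loss (Keras’s name for the L1 loss reported here) and the explicit epsilon are assumptions based on the settings listed earlier:

from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam

# Assemble the functional graph above into a trainable model
cnn = Model(inputs=input, outputs=output)

# 'mae' (mean absolute error) is the L1 loss reported throughout this post
cnn.compile(optimizer=Adam(learning_rate=0.001, epsilon=1e-7), loss='mae')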
  2. PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN16(nn.Module):
    def __init__(self, params):
        super(CNN16, self).__init__()

        self.conv1 = nn.Conv3d(1, 64, kernel_size=5, padding=2)
        nn.init.xavier_uniform_(self.conv1.weight)
        nn.init.zeros_(self.conv1.bias)
        self.bn1 = nn.BatchNorm3d(64)
        self.dropout1 = nn.Dropout3d(0.25)

        self.convres = nn.Conv3d(1, 64, kernel_size=1, padding=0)
        nn.init.xavier_uniform_(self.convres.weight)
        nn.init.zeros_(self.convres.bias)

        self.conv2 = nn.Conv3d(64, 64, kernel_size=5, padding=2)
        nn.init.xavier_uniform_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)
        self.bn2 = nn.BatchNorm3d(64)
        self.dropout2 = nn.Dropout3d(0.25)
        
        self.conv3 = nn.Conv3d(64, 128, kernel_size=5, padding=2)
        nn.init.xavier_uniform_(self.conv3.weight)
        nn.init.zeros_(self.conv3.bias)
        self.bn3 = nn.BatchNorm3d(128)
        self.dropout3 = nn.Dropout3d(0.25)

        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(128 * 16 * 16 * 16, 64)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        self.dropout4 = nn.Dropout(0.25)

        self.fc2 = nn.Linear(64, 32)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)
        self.output = nn.Linear(32, len(params))
        nn.init.xavier_uniform_(self.output.weight)

    def forward(self, x):
        input = x
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.dropout1(x)

        res = self.convres(input)
        x = x + res
        
        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = self.dropout2(x)

        res = self.convres(input)
        x = x + res

        x = self.conv3(x)
        x = F.relu(x)
        x = self.bn3(x)
        x = self.dropout3(x)


        x = self.flatten(x)

        x = self.fc1(x)
        x = self.dropout4(x)
        x = self.fc2(x)

        output = self.output(x)

        return output

As you can see, the architecture has three convolutional blocks, with residual connections (both branching off the input) added after the first and second blocks.
As for performance, the Keras model converges to an L1 loss on the order of 8e-3 in about 150 epochs and keeps improving until about 300 epochs. The PyTorch model converges to an L1 loss of about 0.25 and does not change until the early-stopping criterion is met 150 epochs later. Sometimes the validation loss diverges significantly, which is another mystery: why would it do that in PyTorch and not in Keras? In Keras I do not see any significant divergence between training and validation loss.

What have I missed? What can I change or try? Any help or ideas are greatly appreciated. I feel like I need a very deep understanding of PyTorch to troubleshoot this, and maybe I just got lucky with Keras’s default settings. Thanks in advance.

I don’t see any obvious differences between the models, but as a quick check you could count the total number of parameters in both implementations and compare these.
Once this is done you could try to load the parameters from Keras into your PyTorch model (I don’t know if there are tools/packages to do it or if this would be a manual process) and compare all intermediate outputs between these models.
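If it does turn out to be a manual process, a rough sketch of what the per-layer copy would involve (the main gotcha is the axis order of the kernels; the Keras layer handles would come from cnn.layers or cnn.get_layer(...), and the targets are the modules in your CNN16 class):

import numpy as np
import torch

@torch.no_grad()
def copy_conv(keras_layer, torch_conv):
    # Keras Conv3D kernel: (kD, kH, kW, in, out) -> PyTorch Conv3d weight: (out, in, kD, kH, kW)
    kernel, bias = keras_layer.get_weights()
    torch_conv.weight.copy_(torch.from_numpy(np.transpose(kernel, (4, 3, 0, 1, 2))))
    torch_conv.bias.copy_(torch.from_numpy(bias))

@torch.no_grad()
def copy_dense(keras_layer, torch_linear):
    # Keras Dense kernel: (in, out) -> PyTorch Linear weight: (out, in)
    kernel, bias = keras_layer.get_weights()
    torch_linear.weight.copy_(torch.from_numpy(kernel.T))
    torch_linear.bias.copy_(torch.from_numpy(bias))

@torch.no_grad()
def copy_bn(keras_layer, torch_bn):
    # Keras BatchNormalization weights come back as [gamma, beta, moving_mean, moving_variance]
    gamma, beta, mean, var = keras_layer.get_weights()
    torch_bn.weight.copy_(torch.from_numpy(gamma))
    torch_bn.bias.copy_(torch.from_numpy(beta))
    torch_bn.running_mean.copy_(torch.from_numpy(mean))
    torch_bn.running_var.copy_(torch.from_numpy(var))

When comparing outputs afterwards, put the PyTorch model in model.eval() so dropout is disabled and BatchNorm uses the copied running statistics rather than batch statistics.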

Thank you for the advice! I checked, and both models have exactly the same number of parameters, with the exception of the “non-trainable parameters” that Keras counts for its BatchNormalization layers and PyTorch does not (a quick check of where that difference comes from is below). I’ll try to train the Keras model again and figure out how to import the parameters into PyTorch – that’s a great idea.
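That difference is just bookkeeping: PyTorch keeps the BatchNorm running statistics in buffers, which .parameters() does not return, whereas Keras lists them under non-trainable parameters. A sketch of the check (cnn and model as before):

# Keras: model.summary() splits trainable vs. non-trainable parameters; the
# non-trainable ones are the BatchNormalization moving_mean/moving_variance.
keras_total = cnn.count_params()

# PyTorch: .parameters() returns only the learnable tensors; the BatchNorm
# running statistics live in buffers (running_mean, running_var, plus a
# one-element num_batches_tracked counter per BatchNorm layer).
torch_trainable = sum(p.numel() for p in model.parameters())
torch_buffers = sum(b.numel() for b in model.buffers())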

I am running and rerunning the model with different optimizers and it still converges to a loss of about 0.22 and stays there. Previously, when my X and Y data were not matched correctly in the Keras model (i.e., the indices of the ground-truth data fed to the model did not line up with the inputs), I think I got this kind of plateauing behavior. I have checked and rechecked the data and know that what I am feeding to the Dataset object and DataLoader is matched correctly. Is there any way that the DataLoader() and the Dataset() wrapped in it could be mixing up the indices of the data? Here’s how I have them defined:

import torch

class Dataset(torch.utils.data.Dataset):
    def __init__(self, Xarray, Yarray):
        super().__init__()
        self.Xarray=Xarray
        self.Yarray=Yarray
        self.size = Xarray.shape[0]
        self.setup_dataset(self.Xarray, self.Yarray)

    def setup_dataset(self, Xarray, Yarray):
        data = Xarray
        label = Yarray

        self.data = data
        self.label = label

    def __len__(self):
        # Number of data points we have. Alternatively self.data.shape[0] or self.label.shape[0]
        return self.size

    def __getitem__(self, idx):
        data_point = self.data[idx]
        data_label = self.label[idx]
        return data_point, data_label

and how I call them:

from torch.utils.data import DataLoader

trdat = Dataset(torch.from_numpy(Xtr).to(dtype=torch.float), torch.from_numpy(Ytr).to(dtype=torch.float))
trloader = DataLoader(trdat, batch_size=batch_size, shuffle=True)
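Roughly the kind of check I have been running to convince myself the pairing is correct (shuffle turned off so batches come back in their original order; Xtr and Ytr are the same numpy arrays passed into the Dataset above):

import numpy as np

# With shuffle=False the loader must return the arrays in order, so every
# batch can be compared element-wise against the numpy source arrays.
check_loader = DataLoader(trdat, batch_size=batch_size, shuffle=False)
start = 0
for xb, yb in check_loader:
    stop = start + xb.shape[0]
    assert np.allclose(xb.numpy(), Xtr[start:stop].astype(np.float32))
    assert np.allclose(yb.numpy(), Ytr[start:stop].astype(np.float32))
    start = stop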

Thanks for the help!

Update: I exported all of the trained weights of the Keras model and loaded them successfully into the PyTorch model. However, the L1 loss on the test dataset is 6,465. Clearly the math going on inside the PyTorch implementation is not the same as the math in the Keras implementation, and I have no idea what to do. Any ideas are welcome.

Do you see the same outputs for the same input data, so that only the loss calculation differs?
If so, do you see any warnings about broadcasting operations, indicating that the shapes of the model output and the target do not match?
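To illustrate the kind of thing I mean: if the model returns shape (batch, 1) but the targets have shape (batch,), nn.L1Loss silently broadcasts the two to (batch, batch) (recent PyTorch versions print a UserWarning about mismatched input/target sizes), which can easily produce a large, plateauing loss:

import torch
import torch.nn as nn

criterion = nn.L1Loss()

pred = torch.randn(16, 1)   # model output, shape (batch, 1)
target = torch.randn(16)    # labels, shape (batch,)

# Broadcasting silently turns this into a (16, 16) comparison and PyTorch
# emits a UserWarning about the target size differing from the input size.
loss_wrong = criterion(pred, target)

# Matching the shapes gives the intended element-wise L1 loss.
loss_right = criterion(pred.squeeze(1), target)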

Maybe it is a channels-first vs. channels-last problem? My biggest gripe with PyTorch is that it uses channels first.