Keras to PyTorch model not working as expected

Hi there! I converted a Keras CNN classifier model to PyTorch, but the PyTorch model does not train as expected.

Keras code:

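from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, BatchNormalization, MaxPooling1D, Flatten, Dense
from tensorflow.keras.callbacks import CSVLogger

def model():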
    model = Sequential()
    model.add(Conv1D(filters=64, kernel_size=6, activation='relu', 
                    padding='same', input_shape=(4, 1)))
    model.add(BatchNormalization())
    
    # adding a pooling layer
    model.add(MaxPooling1D(pool_size=(3), strides=2, padding='same'))
    
    model.add(Conv1D(filters=64, kernel_size=6, activation='relu', 
                    padding='same', input_shape=(4, 1)))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(pool_size=(3), strides=2, padding='same'))
    
    model.add(Conv1D(filters=64, kernel_size=6, activation='relu', 
                    padding='same', input_shape=(4, 1)))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(pool_size=(3), strides=2, padding='same'))
    
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(4, activation='softmax'))
    
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
keras_model = model()  # build the model; calling .fit on the function itself would fail
logger = CSVLogger('logs.csv', append=True)
his = keras_model.fit(X_train, y_train, epochs=30, batch_size=32,
                      validation_data=(X_test, y_test), callbacks=[logger])

Converted pytorch code:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = nn.Conv1d(4,32, kernel_size=2, padding=2)
        self.bn1 = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        
        self.conv2 = nn.Conv1d(32, 32, kernel_size=2, padding=2)
        self.bn2 = nn.BatchNorm1d(32)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        
        self.conv3 = nn.Conv1d(32, 32, kernel_size=2, padding=2)
        self.bn3 = nn.BatchNorm1d(32)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        
        self.fc1 = nn.Linear(32 * 4, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, 3)

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.pool2(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.pool3(x)
        
        x = torch.flatten(x, 1)
        
        x = self.fc1(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.relu(x)
        
        x = self.fc3(x)
        x = self.softmax(x)
        return x

import torch.optim as optim

criterion = nn.CrossEntropyLoss()

model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        #print (inputs.shape)
        #print (labels.shape)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        #if i % 2000 == 1999:    # print every 2000 mini-batches
        print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.6f}')
        #running_loss = 0.0

print('Finished Training')

The Keras model trains well and reaches good accuracy on the test data. With the PyTorch model, however, the loss keeps increasing during training on the same data split. What am I doing wrong?

Hi @Padmaksha_Roy
I could spot a few issues with the torch version:

Issue 1: There are quite a few differences between the Keras model and Torch model.

For example, the first conv layer has 64 filters in Keras but only 32 in PyTorch, and the kernel_size also differs (6 vs. 2).
Another example: the first dense layer self.fc1 has 64 units in Keras but only 32 units in PyTorch. Likewise, the output layer has 4 units in Keras but self.fc3 has only 3.
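To size fc1, it helps to trace the sequence length through the Keras stack. A minimal sketch, assuming TensorFlow's documented rule for padding='same' (output length = ceil(input length / stride)) for both Conv1D (stride 1) and MaxPooling1D (strides=2); the helper name same_len is just for illustration:

```python
import math


def same_len(length: int, stride: int) -> int:
    """Output length of a Keras layer with padding='same'."""
    return math.ceil(length / stride)


length, filters = 4, 64  # input_shape=(4, 1); every Conv1D has 64 filters
for _ in range(3):                 # three Conv1D -> MaxPooling1D stages
    length = same_len(length, 1)   # Conv1D, stride 1: length unchanged
    length = same_len(length, 2)   # MaxPooling1D, strides=2
    print(length)                  # 2, then 1, then 1

flat_features = length * filters
print(flat_features)  # 64 inputs to the first Dense layer
```

So after Flatten the first Dense layer sees 1 x 64 = 64 features.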

This model is rather compact, so it is feasible (and worthwhile) to go over each layer and make sure it is replicated correctly.
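A minimal sketch of such a layer-by-layer mirror, assuming PyTorch 1.9 or newer (for padding="same" in nn.Conv1d) and inputs that arrive in the Keras layout (N, 4, 1); the class name KerasLikeNet and the random batch are hypothetical. It also applies the ReLU before BatchNorm, as Keras does, and returns raw logits:

```python
import torch
import torch.nn as nn


class KerasLikeNet(nn.Module):
    """Sketch of a PyTorch mirror of the Keras model in the question.

    Expects PyTorch's (batch, channels, length) layout, i.e. (N, 1, 4)
    rather than Keras's channels-last (N, 4, 1).
    """

    def __init__(self, num_classes: int = 4):
        super().__init__()

        def block(in_ch: int) -> nn.Sequential:
            return nn.Sequential(
                # Keras Conv1D(64, 6, padding='same', activation='relu'):
                # the ReLU is fused into the conv, so it runs *before* BatchNorm.
                nn.Conv1d(in_ch, 64, kernel_size=6, padding="same"),
                nn.ReLU(),
                # Keras BatchNormalization defaults: momentum=0.99, epsilon=1e-3.
                # PyTorch defines momentum the other way round (1 - 0.99 = 0.01).
                nn.BatchNorm1d(64, eps=1e-3, momentum=0.01),
                # Same output lengths as MaxPooling1D(3, strides=2, padding='same')
                # here (4 -> 2 -> 1 -> 1), though the padded windows sit differently.
                nn.MaxPool1d(kernel_size=3, stride=2, padding=1),
            )

        self.features = nn.Sequential(block(1), block(64), block(64))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 1, 64),  # sequence length is 1 after three poolings
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),  # raw logits: no softmax here
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = KerasLikeNet()
x_keras_layout = torch.randn(8, 4, 1)              # hypothetical batch, Keras layout
logits = model(x_keras_layout.permute(0, 2, 1))    # -> (8, 1, 4) for Conv1d
print(logits.shape)  # torch.Size([8, 4])
```

Weight initialization still differs (Keras defaults to Glorot-uniform kernels and zero biases), so even a faithful mirror will not produce identical numbers to the Keras model.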

Issue 2: The loss function expects raw logits.
In PyTorch, nn.CrossEntropyLoss expects raw logits, while you are passing it probabilities:

x = self.fc3(x)
x = self.softmax(x)
return x
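To see why this matters: nn.CrossEntropyLoss applies log-softmax internally, so feeding it probabilities applies a softmax a second time. Because probabilities lie in [0, 1], the loss then cannot fall below log((e + C - 1) / e) for C classes, even for a perfect, fully confident prediction. A small sketch with 3 classes, matching self.fc3 in the question:

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
target = torch.tensor([0])

# A perfect, fully confident prediction for class 0, given as probabilities
probs = torch.tensor([[1.0, 0.0, 0.0]])
loss_on_probs = criterion(probs, target).item()

num_classes = 3
floor = math.log((math.e + num_classes - 1) / math.e)
print(f"{loss_on_probs:.4f} vs floor {floor:.4f}")  # 0.5514 vs floor 0.5514

# The same confident prediction as large raw logits drives the loss to ~0
logits = torch.tensor([[20.0, 0.0, 0.0]])
print(criterion(logits, target).item())  # effectively 0
```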

It would help to remove the softmax if you plan to use nn.CrossEntropyLoss.
Alternatively, you can keep the softmax and use nn.NLLLoss, which expects log-probabilities: criterion = nn.NLLLoss(), then loss = criterion(outputs.log(), labels).
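With the softmax removed, the training loop can pass the raw outputs straight to nn.CrossEntropyLoss. Separately, note that the loop in the question adds every batch's loss to running_loss without resetting it inside the epoch and divides by a fixed 2000 on every batch, so within each epoch the printed number can only grow, even if each individual batch loss is falling. The sketch below reports a per-epoch mean instead and uses lr=1e-3, the default learning rate of Keras's Adam (the question uses 0.01). The data and the small placeholder model are hypothetical stand-ins:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the question's data: 4 time steps, 1 channel, 4 classes
torch.manual_seed(0)
X = torch.randn(256, 1, 4)          # already in (N, channels, length) layout
y = torch.randint(0, 4, (256,))     # integer class labels (torch.long)
train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

# Placeholder model that returns raw logits (no softmax at the end)
model = nn.Sequential(nn.Flatten(), nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 4))

criterion = nn.CrossEntropyLoss()                    # applies log-softmax internally
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Keras Adam default rate

epoch_losses = []
for epoch in range(5):
    model.train()
    running_loss, seen = 0.0, 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the mean is exact even for a short last batch
        running_loss += loss.item() * inputs.size(0)
        seen += inputs.size(0)
    epoch_losses.append(running_loss / seen)
    print(f"epoch {epoch + 1}: mean loss {epoch_losses[-1]:.4f}")
```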

Note: for better numerical stability, you should use F.log_softmax (from torch.nn.functional) instead and pass its output to nn.NLLLoss.
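A quick check of both points, on a hypothetical random batch: nn.CrossEntropyLoss on logits matches nn.NLLLoss on F.log_softmax, while softmax followed by .log() breaks down once a probability underflows to zero:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 4)            # hypothetical batch: 8 samples, 4 classes
labels = torch.randint(0, 4, (8,))

ce = nn.CrossEntropyLoss()(logits, labels)                        # raw logits
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)          # log-probabilities
nll_naive = nn.NLLLoss()(F.softmax(logits, dim=1).log(), labels)  # less stable

print(torch.allclose(ce, nll), torch.allclose(ce, nll_naive))  # True True

# With extreme logits, softmax underflows to 0 and .log() gives -inf
extreme = torch.tensor([[0.0, 200.0]])
target = torch.tensor([0])
naive_extreme = nn.NLLLoss()(F.softmax(extreme, dim=1).log(), target)
stable_extreme = nn.NLLLoss()(F.log_softmax(extreme, dim=1), target)
print(naive_extreme)   # tensor(inf)
print(stable_extreme)  # tensor(200.)
```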

@Padmaksha_Roy in addition to the differences @sebastian-sz already pointed out, you are also applying the ReLU right after the conv layers in Keras (via activation='relu'), but after the BatchNorm layers in PyTorch.
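A small sketch of why the order matters, using a hypothetical random batch and one conv/BatchNorm pair: with ReLU last (the question's PyTorch order) every output is non-negative, while with BatchNorm last (the Keras order) the normalization re-centres each channel, so negative values appear:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 1, 4)  # hypothetical batch in (N, channels, length) layout

conv = nn.Conv1d(1, 64, kernel_size=6, padding="same")
bn = nn.BatchNorm1d(64)  # training mode: normalizes with batch statistics

# Keras: Conv1D(..., activation='relu') then BatchNormalization -> conv, ReLU, BN
keras_order = bn(torch.relu(conv(x)))
# Question's PyTorch forward: conv, BN, ReLU
torch_order = torch.relu(bn(conv(x)))

print((torch_order >= 0).all().item())  # True: ReLU is the last op
print((keras_order < 0).any().item())   # True: BN re-centres the ReLU output
```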
