Unable to replicate Keras model in PyTorch

Hi all,
I am currently working on a DL project in which I would like to reproduce the work from an existing research paper. It is about detecting sleep episodes from EEG data using a CNN model. The source code uses Keras, and I am having a hard time reproducing the model in PyTorch.

The existing Keras model runs very fast, and after 2 epochs it has the following stats (this is a 4-class classification problem):

Training loss = 0.431
Training accuracy = 0.842

On validation set:
Kappa score per class = [0.373 0.5571 0.033 0.129]
Overall kappa = 0.319
accuracy = 0.6875
precision = [0.993 0.514 0.029 0.115 ]
recall = [0.696 0.719 0.180 0.573]
f1 = [0.818 0.600 0.05 0.192]

I created a PyTorch version of the model and ran it on the same data. It runs much slower, and after 6 epochs this is what I got:

Training loss = 0.379899
Training accuracy = 0.869

On validation set:
Kappa score per class = [ 0.111 0.079 -0.016 0.040]
Overall kappa = 0.078
accuracy = 0.559
precision = [0.682 0.432 0.012 0.105]
recall = [0.797 0.106 0.035 0.155 ]
f1 = [0.735 0.170 0.018 0.125]

The only differences between the 2 models are in the initialization, optimizer, and loss function:

  • The Keras model uses the Nadam optimizer; the PyTorch model uses SGD (I am not aware of Nadam being available in PyTorch).
  • The Keras model uses glorot_normal for kernel initialization; the PyTorch model uses xavier_uniform.
  • The Keras model uses the “Categorical Cross Entropy” loss function with softmax as the last output layer; the PyTorch model uses “CrossEntropyLoss”. I do not use Softmax in the last layer since CrossEntropyLoss internally uses LogSoftmax.
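
For the first two differences listed above, PyTorch has closer counterparts that could be tried. The following is only a sketch: `conv` is a hypothetical stand-in layer, the learning rate is an assumed value, and it assumes a PyTorch release that ships `torch.optim.NAdam`. Keras documents `glorot_normal` as a truncated normal, so `xavier_normal_` is close but probably not identical.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in layer, used only to demonstrate the calls.
conv = nn.Conv2d(3, 32, kernel_size=(1, 3), padding=(0, 1))

# Normal-distributed Glorot/Xavier init, closer to Keras 'glorot_normal'
# than xavier_uniform_ (Keras truncates the normal, PyTorch does not).
nn.init.xavier_normal_(conv.weight)
nn.init.zeros_(conv.bias)

# Nadam counterpart; requires a PyTorch version that includes torch.optim.NAdam.
# lr=1e-3 is an assumed value, check it against the Keras version in use.
optimizer = torch.optim.NAdam(conv.parameters(), lr=1e-3)
```

With the full model, `conv.parameters()` would be replaced by `model.parameters()`.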

I have spent a lot of time trying to understand why my PyTorch model performs much worse, to no avail. I would appreciate it if gurus and experts could advise here.

Here is the Keras model:

def build_model(data_dim, n_channels, n_cl):
	eeg_channels = 1
	act_conv = 'relu'
	init_conv = 'glorot_normal'
	dp_conv = 0.3
	def cnn_block(input_shape):
		input = Input(shape=input_shape)
		x = GaussianNoise(0.0005)(input)
		x = Conv2D(32, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
		x = BatchNormalization()(x)
		x = Activation(act_conv)(x)
		x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		
		
		x = Conv2D(64, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
		x = BatchNormalization()(x)
		x = Activation(act_conv)(x)
		x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		for i in range(4):
			x = Conv2D(128, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
			x = BatchNormalization()(x)
			x = Activation(act_conv)(x)
			x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		for i in range(6):
			x = Conv2D(256, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
			x = BatchNormalization()(x)
			x = Activation(act_conv)(x)
			x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		flatten1 = Flatten()(x)
		cnn_eeg = Model(inputs=input, outputs=flatten1)
		return cnn_eeg
		
	hidden_units1  = 256
	dp_dense = 0.5

	eeg_channels = 1
	eog_channels = 2

	input_eeg = Input(shape=( data_dim, 1,  3))
	cnn_eeg = cnn_block(( data_dim, 1, 3))
	x_eeg = cnn_eeg(input_eeg)
	x = BatchNormalization()(x_eeg)
	x = Dropout(dp_dense)(x)
	x =  Dense(units=hidden_units1, activation=act_conv, kernel_initializer=init_conv)(x)
	x = BatchNormalization()(x)
	x = Dropout(dp_dense)(x)

	predictions = Dense(units=n_cl, activation='softmax', kernel_initializer=init_conv)(x)

	model = Model(inputs=[input_eeg] , outputs=[predictions])
	return [cnn_eeg, model]

The model is used as follows:

    [cnn_eeg, model] = build_model(data_dim, n_channels, n_cl)
    Nadam = optimizers.Nadam( )
    model.compile(optimizer='Nadam',  loss='categorical_crossentropy', metrics=['accuracy'], sample_weight_mode=None)
    print(cnn_eeg.summary())
    print(model.summary())


    model.fit_generator(generator_train, steps_per_epoch = steps_per_epoch, class_weight = weight, epochs = 1, verbose=1,  callbacks=[history], initial_epoch=0 )

Printout of the model:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 3200, 1, 3)]      0         
_________________________________________________________________
gaussian_noise (GaussianNois (None, 3200, 1, 3)        0         
_________________________________________________________________
conv2d (Conv2D)              (None, 3200, 1, 32)       320       
_________________________________________________________________
batch_normalization (BatchNo (None, 3200, 1, 32)       128       
_________________________________________________________________
activation (Activation)      (None, 3200, 1, 32)       0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 1600, 1, 32)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 1600, 1, 64)       6208      
_________________________________________________________________
batch_normalization_1 (Batch (None, 1600, 1, 64)       256       
_________________________________________________________________
activation_1 (Activation)    (None, 1600, 1, 64)       0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 800, 1, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 800, 1, 128)       24704     
_________________________________________________________________
batch_normalization_2 (Batch (None, 800, 1, 128)       512       
_________________________________________________________________
activation_2 (Activation)    (None, 800, 1, 128)       0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 400, 1, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 400, 1, 128)       49280     
_________________________________________________________________
batch_normalization_3 (Batch (None, 400, 1, 128)       512       
_________________________________________________________________
activation_3 (Activation)    (None, 400, 1, 128)       0         
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 200, 1, 128)       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 200, 1, 128)       49280     
_________________________________________________________________
batch_normalization_4 (Batch (None, 200, 1, 128)       512       
_________________________________________________________________
activation_4 (Activation)    (None, 200, 1, 128)       0         
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 100, 1, 128)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 100, 1, 128)       49280     
_________________________________________________________________
batch_normalization_5 (Batch (None, 100, 1, 128)       512       
_________________________________________________________________
activation_5 (Activation)    (None, 100, 1, 128)       0         
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 50, 1, 128)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 50, 1, 256)        98560     
_________________________________________________________________
batch_normalization_6 (Batch (None, 50, 1, 256)        1024      
_________________________________________________________________
activation_6 (Activation)    (None, 50, 1, 256)        0         
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 25, 1, 256)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 25, 1, 256)        196864    
_________________________________________________________________
batch_normalization_7 (Batch (None, 25, 1, 256)        1024      
_________________________________________________________________
activation_7 (Activation)    (None, 25, 1, 256)        0         
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 13, 1, 256)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 13, 1, 256)        196864    
_________________________________________________________________
batch_normalization_8 (Batch (None, 13, 1, 256)        1024      
_________________________________________________________________
activation_8 (Activation)    (None, 13, 1, 256)        0         
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 7, 1, 256)         0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 1, 256)         196864    
_________________________________________________________________
batch_normalization_9 (Batch (None, 7, 1, 256)         1024      
_________________________________________________________________
activation_9 (Activation)    (None, 7, 1, 256)         0         
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 4, 1, 256)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 4, 1, 256)         196864    
_________________________________________________________________
batch_normalization_10 (Batc (None, 4, 1, 256)         1024      
_________________________________________________________________
activation_10 (Activation)   (None, 4, 1, 256)         0         
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 2, 1, 256)         0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 2, 1, 256)         196864    
_________________________________________________________________
batch_normalization_11 (Batc (None, 2, 1, 256)         1024      
_________________________________________________________________
activation_11 (Activation)   (None, 2, 1, 256)         0         
_________________________________________________________________
max_pooling2d_11 (MaxPooling (None, 1, 1, 256)         0         
_________________________________________________________________
flatten (Flatten)            (None, 256)               0         
=================================================================
Total params: 1,270,528
Trainable params: 1,266,240
Non-trainable params: 4,288
_________________________________________________________________
None
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 3200, 1, 3)]      0         
_________________________________________________________________
model (Functional)           (None, 256)               1270528   
_________________________________________________________________
batch_normalization_12 (Batc (None, 256)               1024      
_________________________________________________________________
dropout (Dropout)            (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 256)               65792     
_________________________________________________________________
batch_normalization_13 (Batc (None, 256)               1024      
_________________________________________________________________
dropout_1 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 1028      
=================================================================
Total params: 1,339,396
Trainable params: 1,334,084
Non-trainable params: 5,312
_________________________________________________________________

Below is my Pytorch model:

class MSECNN16s(nn.Module):
    
    def __init__(self):
    # input shape = (batchsize, 3, 1, windowsize=16*200=3200)
    
        super(MSECNN16s, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=(1,3), padding=(0,1))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1,3), padding=(0,1))
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1,3), padding=(0,1))
                                
        self.conv4 = nn.ModuleList()
        for x in range(3):
            conv = nn.Conv2d(128, 128, kernel_size=(1,3), padding=(0,1))
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)
            self.conv4.append(conv)
        
        self.conv5 = nn.Conv2d(128, 256, kernel_size=(1,3), padding=(0,1))        
        
        self.conv6 = nn.ModuleList()
        for x in range(5):
            conv = nn.Conv2d(256, 256, kernel_size=(1,3), padding=(0,1))
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)
            self.conv6.append(conv)
        
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 4)
        
        nn.init.xavier_uniform_(self.conv1.weight)
        nn.init.zeros_(self.conv1.bias)
        
        nn.init.xavier_uniform_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)
        
        nn.init.xavier_uniform_(self.conv3.weight)
        nn.init.zeros_(self.conv3.bias)
        
        nn.init.xavier_uniform_(self.conv5.weight)
        nn.init.zeros_(self.conv5.bias)
        
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)
        
        

    def forward(self, x):
        # x = (batchsize, 1, 3, windowsize=16*200)
        
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        x = x.permute(0, 2, 1, 3) #convert to (batchsize, 3, 1, 3200)                

        std = 0.0005
        x = (torch.randn_like(x) * std) + x


        x = self.conv1(x)  #output: (batchsize, 32, 1, 3200)
        x = nn.BatchNorm2d(32).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1,2))(x) #output: (batchsize, 32, 1, 1600)
        

        x = self.conv2(x) #output: (batchsize, 64, 1, 1600)
        x = nn.BatchNorm2d(64).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1,2))(x) #output: (batchsize, 64, 1, 800)
                
        x = self.conv3(x) #output: (batchsize, 128, 1, 800)
        x = nn.BatchNorm2d(128).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1,2))(x) #output: (batchsize, 128, 1, 400)
        
        for conv in self.conv4:
            x = conv(x) 
            x = nn.BatchNorm2d(128).to(device)(x)
            x = F.relu(x)
            
            padding = (0,0)
            if x.shape[-1] % 2 > 0: padding = (0,1)
            x = nn.MaxPool2d(kernel_size=(1,2), padding = padding)(x)  # output width will be 200->100->50
            
        x = self.conv5(x) #output: (batchsize, 256, 1, 50)
        x = nn.BatchNorm2d(256).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1,2))(x) #output: (batchsize, 256, 1, 25)
        
        for conv in self.conv6:
            x = conv(x) 
            x = nn.BatchNorm2d(256).to(device)(x)
            x = F.relu(x)
            
            padding = (0,0)
            if x.shape[-1] % 2 > 0: padding = (0,1)
            x = nn.MaxPool2d(kernel_size=(1,2), padding = padding)(x)  # output width will be 13->7->4->2->1
        
        # x is (batchsize, 256, 1, 1)
        x = x.squeeze() #x is (batchsize, 256)
        x = nn.BatchNorm1d(256).to(device)(x)
        x = nn.Dropout(p=0.5)(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = nn.BatchNorm1d(256).to(device)(x)
        x = nn.Dropout(p=0.5)(x)
        x = self.fc2(x)        
                
        return x

with the following optimizer and criterion:

    model  = MSECNN16s()
    model.to(device)

    criterion = nn.CrossEntropyLoss(weight=torch.Tensor([1.0, 11.1, 102.9, 38.1]).to(device))

    # load the optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001  )

Any advice is appreciated! Many thanks!

Regards
Edwin
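
One point in the PyTorch model above that may be worth checking: the BatchNorm and Dropout layers are constructed inside forward() on every call. Layers built that way are not part of model.parameters(), so the optimizer never updates them; their running statistics are thrown away after each batch; and since a freshly constructed module starts in training mode, model.eval() cannot switch them off, so dropout stays active during validation. Below is a minimal sketch of the usual pattern, where each layer is registered once in __init__. `ConvBlock` and `TinyNet` are hypothetical names and the network is far smaller than MSECNN16s; `ceil_mode=True` is used here as an assumed substitute for the manual padding on odd widths.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU -> MaxPool, with every layer registered once."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=(1, 3), padding=(0, 1))
        self.bn = nn.BatchNorm2d(out_ch)
        # ceil_mode=True rounds odd widths up (e.g. 25 -> 13).
        self.pool = nn.MaxPool2d(kernel_size=(1, 2), ceil_mode=True)

    def forward(self, x):
        return self.pool(F.relu(self.bn(self.conv(x))))


class TinyNet(nn.Module):
    """Hypothetical two-block model for illustration only."""

    def __init__(self, n_classes=4):
        super().__init__()
        self.features = nn.Sequential(ConvBlock(3, 32), ConvBlock(32, 64))
        self.bn = nn.BatchNorm1d(64)
        self.drop = nn.Dropout(p=0.5)
        self.fc = nn.Linear(64, n_classes)

    def forward(self, x):
        x = self.features(x)   # (batch, 64, 1, 1) for an input width of 4
        x = x.flatten(1)       # (batch, 64); safe even when batch size is 1
        x = self.drop(self.bn(x))
        return self.fc(x)      # raw logits, as expected by CrossEntropyLoss


model = TinyNet()
x = torch.randn(2, 3, 1, 4)

# In eval mode the registered Dropout is disabled and BatchNorm uses its
# running statistics, so two passes over the same input agree.
model.eval()
with torch.no_grad():
    out1 = model(x)
    out2 = model(x)
same = torch.equal(out1, out2)

# The BatchNorm layers are now visible to the model (and to the optimizer).
n_bn = sum(isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)) for m in model.modules())
```

Calling model.train() before each training epoch and model.eval() before validation then toggles all registered BatchNorm and Dropout layers together.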

I guess it would be because of the loss function. You don’t have a categorical cross-entropy loss here.
You can use a custom loss function like this one together with softmax.

def custom_categorical_cross_entropy(y_pred, y_true):
    y_pred = torch.clamp(y_pred, 1e-9, 1 - 1e-9)
    return -(y_true * torch.log(y_pred)).sum(dim=1).mean()
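
For comparison, here is a small check of how a function like this relates to the built-in loss. It is a sketch on random data (the logits and targets are made up), and it suggests that the unweighted cross-entropy on raw logits already gives the same value as softmax followed by a hand-written categorical cross-entropy, so the loss function alone may not explain the gap.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 4)            # made-up raw model outputs
target = torch.randint(0, 4, (8,))    # made-up integer class labels


def categorical_cross_entropy(probs, one_hot):
    """Keras-style loss on probabilities and one-hot targets."""
    probs = torch.clamp(probs, 1e-9, 1 - 1e-9)
    return -(one_hot * torch.log(probs)).sum(dim=1).mean()


# Built-in loss: takes logits and integer labels, applies log-softmax itself.
loss_builtin = F.cross_entropy(logits, target)

# Hand-written loss: explicit softmax, then one-hot targets.
loss_manual = categorical_cross_entropy(
    F.softmax(logits, dim=1), F.one_hot(target, num_classes=4).float()
)

print(loss_builtin.item(), loss_manual.item())
```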

Hi, thanks for the reply.
I assume the code you provided is for the case of balanced classes. How would it work if we have weights for the different classes? Is the following code correct?

def custom_categorical_cross_entropy(y_pred, y_true, weights):

    weights = weights.view(1,-1)
        
    y_true = F.one_hot(y_true, num_classes=weights.shape[-1])    
    y_pred = torch.clamp(y_pred, 1e-9, 1 - 1e-9)
    
    #nominator = (-y_true * torch.log(y_pred) * weights).sum()
    nominator = (-y_true * torch.log(y_pred)).sum()
    denominator = (y_true * weights).sum()
    
    return nominator / denominator

What I am not sure of is whether I should multiply the loss by the class weights.
I am also wondering how this weight is taken into account in Keras: in my reference model it is passed as an argument to model.fit, but when I check the Keras source code, the weights do not seem to be taken into account there. Any idea?

model.compile(optimizer='Nadam',  loss='categorical_crossentropy', metrics=['accuracy'], sample_weight_mode=None)

model.fit_generator(generator_train, steps_per_epoch = steps_per_epoch, class_weight = weight, epochs = 1, verbose=1,  callbacks=[history], initial_epoch=0 )
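
On the class-weight question, the following sketch compares two ways of reducing a weighted loss. The first reproduces what nn.CrossEntropyLoss(weight=...) does with its default mean reduction: each sample's loss is multiplied by the weight of its class and the sum is divided by the sum of those weights. The second divides by the batch size instead; this is, as far as I understand, how Keras applies class_weight (as a per-sample weight), but treat that as an assumption to verify against the Keras version in use. The logits and targets are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 4)                          # made-up raw outputs
target = torch.tensor([0, 0, 0, 1, 1, 2, 3, 0])     # made-up labels
class_w = torch.tensor([1.0, 11.1, 102.9, 38.1])    # weights from the post

per_sample = F.cross_entropy(logits, target, reduction="none")
sample_w = class_w[target]                          # weight of each sample's class

# PyTorch weighted mean: divide by the sum of the sample weights.
pytorch_style = (sample_w * per_sample).sum() / sample_w.sum()

# Assumed Keras-style reduction: divide by the number of samples.
keras_style = (sample_w * per_sample).sum() / per_sample.numel()

# Built-in weighted loss for reference.
builtin = F.cross_entropy(logits, target, weight=class_w)

print(builtin.item(), pytorch_style.item(), keras_style.item())
```

If the assumption about Keras holds, the two reductions differ by a factor equal to the mean sample weight in the batch, which changes the effective learning rate rather than the direction of the gradient.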