Hi all,
I am currently working on a deep learning project in which I am trying to reproduce the results of an existing research paper on detecting sleep episodes from EEG data with a CNN. The original source code is written in Keras, and I am having a hard time reproducing the model in PyTorch.
The existing Keras model trains very fast, and after 2 epochs it produces the following stats (this is a 4-class classification problem):
Training loss = 0.431
Training accuracy = 0.842
On validation set:
Kappa score per class = [0.373 0.5571 0.033 0.129]
Overall kappa = 0.319
accuracy = 0.6875
precision = [0.993 0.514 0.029 0.115 ]
recall = [0.696 0.719 0.180 0.573]
f1 = [0.818 0.600 0.05 0.192]
I created a PyTorch version of the model and ran it on the same data. It runs much slower, and after 6 epochs this is what I got:
Training loss = 0.379899
Training accuracy = 0.869
On validation set:
Kappa score per class = [ 0.111 0.079 -0.016 0.040]
Overall kappa = 0.078
accuracy = 0.559
precision = [0.682 0.432 0.012 0.105]
recall = [0.797 0.106 0.035 0.155 ]
f1 = [0.735 0.170 0.018 0.125]
The only differences between the two models are the initialization, the optimizer, and the loss function (summarized in the short sketch right after this list):
- The Keras model uses the Nadam optimizer; the PyTorch model uses SGD (I am not aware of Nadam being available in PyTorch).
- The Keras model uses glorot_normal for kernel initialization; the PyTorch model uses xavier_uniform.
- The Keras model uses categorical cross-entropy loss with softmax as the last output layer; the PyTorch model uses CrossEntropyLoss. I do not apply softmax in the last layer since CrossEntropyLoss internally applies LogSoftmax.
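To make those three differences concrete, here is a minimal side-by-side sketch of just these settings, pulled out of the full code further down (nothing new, only the lines that differ):

# Keras (excerpted from build_model below)
x = Conv2D(32, (3, 1), padding='same', kernel_initializer='glorot_normal')(x)
model.compile(optimizer='Nadam', loss='categorical_crossentropy', metrics=['accuracy'])

# PyTorch (excerpted from MSECNN16s below)
# note: glorot_normal draws from a normal distribution; xavier_uniform_ is the uniform
# variant of the same Glorot/Xavier scheme
nn.init.xavier_uniform_(self.conv1.weight)
criterion = nn.CrossEntropyLoss(weight=torch.Tensor([1.0, 11.1, 102.9, 38.1]).to(device))  # model outputs raw logits, no softmax
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)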
I have spent a lot of time trying to understand why my PyTorch model performs so much worse, to no avail. I would appreciate it if the gurus and experts here could advise.
Here is the Keras model:
def build_model(data_dim, n_channels, n_cl):
    eeg_channels = 1
    act_conv = 'relu'
    init_conv = 'glorot_normal'
    dp_conv = 0.3

    def cnn_block(input_shape):
        input = Input(shape=input_shape)
        x = GaussianNoise(0.0005)(input)
        x = Conv2D(32, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
        x = BatchNormalization()(x)
        x = Activation(act_conv)(x)
        x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
        x = Conv2D(64, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
        x = BatchNormalization()(x)
        x = Activation(act_conv)(x)
        x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
        for i in range(4):
            x = Conv2D(128, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
            x = BatchNormalization()(x)
            x = Activation(act_conv)(x)
            x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
        for i in range(6):
            x = Conv2D(256, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
            x = BatchNormalization()(x)
            x = Activation(act_conv)(x)
            x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
        flatten1 = Flatten()(x)
        cnn_eeg = Model(inputs=input, outputs=flatten1)
        return cnn_eeg

    hidden_units1 = 256
    dp_dense = 0.5
    eeg_channels = 1
    eog_channels = 2

    input_eeg = Input(shape=(data_dim, 1, 3))
    cnn_eeg = cnn_block((data_dim, 1, 3))
    x_eeg = cnn_eeg(input_eeg)
    x = BatchNormalization()(x_eeg)
    x = Dropout(dp_dense)(x)
    x = Dense(units=hidden_units1, activation=act_conv, kernel_initializer=init_conv)(x)
    x = BatchNormalization()(x)
    x = Dropout(dp_dense)(x)
    predictions = Dense(units=n_cl, activation='softmax', kernel_initializer=init_conv)(x)
    model = Model(inputs=[input_eeg], outputs=[predictions])
    return [cnn_eeg, model]
The model is used as follows:
[cnn_eeg, model] = build_model(data_dim, n_channels, n_cl)
Nadam = optimizers.Nadam( )
model.compile(optimizer='Nadam', loss='categorical_crossentropy', metrics=['accuracy'], sample_weight_mode=None)
print(cnn_eeg.summary())
print(model.summary())
model.fit_generator(generator_train, steps_per_epoch = steps_per_epoch, class_weight = weight, epochs = 1, verbose=1, callbacks=[history], initial_epoch=0 )
Printout of the models:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 3200, 1, 3)] 0
_________________________________________________________________
gaussian_noise (GaussianNois (None, 3200, 1, 3) 0
_________________________________________________________________
conv2d (Conv2D) (None, 3200, 1, 32) 320
_________________________________________________________________
batch_normalization (BatchNo (None, 3200, 1, 32) 128
_________________________________________________________________
activation (Activation) (None, 3200, 1, 32) 0
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 1600, 1, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 1600, 1, 64) 6208
_________________________________________________________________
batch_normalization_1 (Batch (None, 1600, 1, 64) 256
_________________________________________________________________
activation_1 (Activation) (None, 1600, 1, 64) 0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 800, 1, 64) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 800, 1, 128) 24704
_________________________________________________________________
batch_normalization_2 (Batch (None, 800, 1, 128) 512
_________________________________________________________________
activation_2 (Activation) (None, 800, 1, 128) 0
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 400, 1, 128) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 400, 1, 128) 49280
_________________________________________________________________
batch_normalization_3 (Batch (None, 400, 1, 128) 512
_________________________________________________________________
activation_3 (Activation) (None, 400, 1, 128) 0
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 200, 1, 128) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 200, 1, 128) 49280
_________________________________________________________________
batch_normalization_4 (Batch (None, 200, 1, 128) 512
_________________________________________________________________
activation_4 (Activation) (None, 200, 1, 128) 0
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 100, 1, 128) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 100, 1, 128) 49280
_________________________________________________________________
batch_normalization_5 (Batch (None, 100, 1, 128) 512
_________________________________________________________________
activation_5 (Activation) (None, 100, 1, 128) 0
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 50, 1, 128) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 50, 1, 256) 98560
_________________________________________________________________
batch_normalization_6 (Batch (None, 50, 1, 256) 1024
_________________________________________________________________
activation_6 (Activation) (None, 50, 1, 256) 0
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 25, 1, 256) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 25, 1, 256) 196864
_________________________________________________________________
batch_normalization_7 (Batch (None, 25, 1, 256) 1024
_________________________________________________________________
activation_7 (Activation) (None, 25, 1, 256) 0
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 13, 1, 256) 0
_________________________________________________________________
conv2d_8 (Conv2D) (None, 13, 1, 256) 196864
_________________________________________________________________
batch_normalization_8 (Batch (None, 13, 1, 256) 1024
_________________________________________________________________
activation_8 (Activation) (None, 13, 1, 256) 0
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 7, 1, 256) 0
_________________________________________________________________
conv2d_9 (Conv2D) (None, 7, 1, 256) 196864
_________________________________________________________________
batch_normalization_9 (Batch (None, 7, 1, 256) 1024
_________________________________________________________________
activation_9 (Activation) (None, 7, 1, 256) 0
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 4, 1, 256) 0
_________________________________________________________________
conv2d_10 (Conv2D) (None, 4, 1, 256) 196864
_________________________________________________________________
batch_normalization_10 (Batc (None, 4, 1, 256) 1024
_________________________________________________________________
activation_10 (Activation) (None, 4, 1, 256) 0
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 2, 1, 256) 0
_________________________________________________________________
conv2d_11 (Conv2D) (None, 2, 1, 256) 196864
_________________________________________________________________
batch_normalization_11 (Batc (None, 2, 1, 256) 1024
_________________________________________________________________
activation_11 (Activation) (None, 2, 1, 256) 0
_________________________________________________________________
max_pooling2d_11 (MaxPooling (None, 1, 1, 256) 0
_________________________________________________________________
flatten (Flatten) (None, 256) 0
=================================================================
Total params: 1,270,528
Trainable params: 1,266,240
Non-trainable params: 4,288
_________________________________________________________________
None
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 3200, 1, 3)] 0
_________________________________________________________________
model (Functional) (None, 256) 1270528
_________________________________________________________________
batch_normalization_12 (Batc (None, 256) 1024
_________________________________________________________________
dropout (Dropout) (None, 256) 0
_________________________________________________________________
dense (Dense) (None, 256) 65792
_________________________________________________________________
batch_normalization_13 (Batc (None, 256) 1024
_________________________________________________________________
dropout_1 (Dropout) (None, 256) 0
_________________________________________________________________
dense_1 (Dense) (None, 4) 1028
=================================================================
Total params: 1,339,396
Trainable params: 1,334,084
Non-trainable params: 5,312
_________________________________________________________________
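One thing to note when comparing the two versions: Keras is channels-last, so its tensors are (batch, 3200, 1, 3) with the (3, 1) kernels running along the time axis, while my PyTorch model is channels-first, so I permute the input at the top of forward() and the (1, 3) kernels run along the time axis there:

# Keras: channels-last, (batch, 3200, 1, 3); PyTorch: channels-first, (batch, 3, 1, 3200)
x = x.permute(0, 2, 1, 3)  # (batch, 1, 3, 3200) -> (batch, 3, 1, 3200), done inside forward() below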
Below is my PyTorch model:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MSECNN16s(nn.Module):
    def __init__(self):
        # input shape = (batchsize, 3, 1, windowsize=16*200=3200)
        super(MSECNN16s, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=(1, 3), padding=(0, 1))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1, 3), padding=(0, 1))
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 3), padding=(0, 1))
        self.conv4 = nn.ModuleList()
        for x in range(3):
            conv = nn.Conv2d(128, 128, kernel_size=(1, 3), padding=(0, 1))
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)
            self.conv4.append(conv)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=(1, 3), padding=(0, 1))
        self.conv6 = nn.ModuleList()
        for x in range(5):
            conv = nn.Conv2d(256, 256, kernel_size=(1, 3), padding=(0, 1))
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)
            self.conv6.append(conv)
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 4)

        nn.init.xavier_uniform_(self.conv1.weight)
        nn.init.zeros_(self.conv1.bias)
        nn.init.xavier_uniform_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)
        nn.init.xavier_uniform_(self.conv3.weight)
        nn.init.zeros_(self.conv3.bias)
        nn.init.xavier_uniform_(self.conv5.weight)
        nn.init.zeros_(self.conv5.bias)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        # x = (batchsize, 1, 3, windowsize=16*200)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        x = x.permute(0, 2, 1, 3)  # convert to (batchsize, 3, 1, 3200)
        std = 0.0005
        x = (torch.randn_like(x) * std) + x
        x = self.conv1(x)  # output: (batchsize, 32, 1, 3200)
        x = nn.BatchNorm2d(32).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1, 2))(x)  # output: (batchsize, 32, 1, 1600)
        x = self.conv2(x)  # output: (batchsize, 64, 1, 1600)
        x = nn.BatchNorm2d(64).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1, 2))(x)  # output: (batchsize, 64, 1, 800)
        x = self.conv3(x)  # output: (batchsize, 128, 1, 800)
        x = nn.BatchNorm2d(128).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1, 2))(x)  # output: (batchsize, 128, 1, 400)
        for conv in self.conv4:
            x = conv(x)
            x = nn.BatchNorm2d(128).to(device)(x)
            x = F.relu(x)
            padding = (0, 0)
            if x.shape[-1] % 2 > 0: padding = (0, 1)
            x = nn.MaxPool2d(kernel_size=(1, 2), padding=padding)(x)  # width goes 400 -> 200 -> 100 -> 50 over the three blocks
        x = self.conv5(x)  # output: (batchsize, 256, 1, 50)
        x = nn.BatchNorm2d(256).to(device)(x)
        x = F.relu(x)
        x = nn.MaxPool2d(kernel_size=(1, 2))(x)  # output: (batchsize, 256, 1, 25)
        for conv in self.conv6:
            x = conv(x)
            x = nn.BatchNorm2d(256).to(device)(x)
            x = F.relu(x)
            padding = (0, 0)
            if x.shape[-1] % 2 > 0: padding = (0, 1)
            x = nn.MaxPool2d(kernel_size=(1, 2), padding=padding)(x)  # width goes 25 -> 13 -> 7 -> 4 -> 2 -> 1 over the five blocks
        # x is (batchsize, 256, 1, 1)
        x = x.squeeze()  # x is (batchsize, 256)
        x = nn.BatchNorm1d(256).to(device)(x)
        x = nn.Dropout(p=0.5)(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = nn.BatchNorm1d(256).to(device)(x)
        x = nn.Dropout(p=0.5)(x)
        x = self.fc2(x)
        return x
It is trained with the following optimizer and criterion:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MSECNN16s()
model.to(device)

criterion = nn.CrossEntropyLoss(weight=torch.Tensor([1.0, 11.1, 102.9, 38.1]).to(device))
# load the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)
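The model, criterion, and optimizer are wired together in the usual way; this is a simplified sketch of one training pass rather than my exact code (the real loop also computes the metrics and the class weights, and train_loader here just stands in for my data pipeline):

model.train()
for inputs, labels in train_loader:                   # inputs: (batchsize, 1, 3, 3200), labels: (batchsize,) class indices
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    logits = model(inputs)                            # raw scores, shape (batchsize, 4); no softmax applied
    loss = criterion(logits, labels)                  # CrossEntropyLoss applies log-softmax internally
    loss.backward()
    optimizer.step()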
Any advice is appreciated! Many thanks!
Regards
Edwin