Implementation of a Keras model doesn't converge

Hello there!
I must be doing something wrong, but after a couple of days I’m close to giving up. I’m trying to implement the model described in the following paper: Real-time CNN for Emotion and gender classification. The authors have also implemented the proposed model in Keras, so I used their implementation as a reference to verify my model. My implementation is below:

class MiniXception(nn.Module):
    def __init__(self):
        super(MiniXception, self).__init__()

        self.beginning = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(8, eps=1e-03),
            nn.ReLU(),
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, stride=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU()
        )

        self.core = nn.ModuleList([
            nn.ModuleDict({
                'separable': nn.Sequential(
                    SeparableConv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=16),
                    nn.ReLU(),
                    SeparableConv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=16),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                ),
                'residual': nn.Sequential(
                    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=1, stride=2, padding=0, bias=False),
                    nn.BatchNorm2d(num_features=16)
                )
            }),

            nn.ModuleDict({
                'separable': nn.Sequential(
                    SeparableConv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=32),
                    nn.ReLU(),
                    SeparableConv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=32),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                ),
                'residual': nn.Sequential(
                    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, stride=2, padding=0, bias=False),
                    nn.BatchNorm2d(num_features=32)
                )
            }),
            nn.ModuleDict({
                'separable': nn.Sequential(
                    SeparableConv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=64),
                    nn.ReLU(),
                    SeparableConv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=64),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                ),
                'residual': nn.Sequential(
                    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=2, padding=0, bias=False),
                    nn.BatchNorm2d(num_features=64)
                )
            }),
            nn.ModuleDict({
                'separable': nn.Sequential(
                    SeparableConv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=128),
                    nn.ReLU(),
                    SeparableConv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(num_features=128),
                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                ),
                'residual': nn.Sequential(
                    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=2, padding=0, bias=False),
                    nn.BatchNorm2d(128)
                )
            })
        ])

        self.final = nn.Sequential(
            nn.Conv2d(128, 7, 3),
            GlobalAvgPooling2d(),
            nn.Softmax(dim=1)
        )

        def initialize(m):
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)

        self.beginning.apply(initialize)
        self.core.apply(initialize)
        self.final.apply(initialize)

    def forward(self, x):
        out = self.beginning(x)
        for m in self.core:
            out_sep = m['separable'](out)
            out_res = m['residual'](out)
            out = out_sep + out_res
        out = self.final(out)
        return out

class GlobalAvgPooling2d(nn.Module):
    def __init__(self):
        super(GlobalAvgPooling2d, self).__init__()

    def forward(self, x):
        return torch.mean(x, (2, 3))

class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=in_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               groups=in_channels,
                               bias=bias)

        self.point_wise = nn.Conv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    dilation=1,
                                    groups=1,
                                    bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.point_wise(x)
        return x

Training routine:


def train(model, train_loader, test_loader, num_epochs, criterion, optimizer, scheduler):
    
    loss_history = []
    acc_history = []
    best = float('inf')
    # batch size = 32 as it is in the original implementation
    for epoch in range(num_epochs):
        loss_val = 0.0

        for batch_idx, sample in enumerate(train_loader):
            x, y = sample[0], sample[1]
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            loss_val += loss.item()
        
        accuracy = __validate__(model, test_loader)
        scheduler.step(loss_val)
 
        loss_history.append(loss_val)
        acc_history.append(accuracy)

    return loss_history, acc_history


def __validate__(model, data_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_index, sample in enumerate(data_loader):
            x, y = sample[0], sample[1]
            outputs = model(x)
            _, predicted = torch.max(outputs.data, 1)

            total += y.size(0)
            correct += (predicted == y).sum().item()

    accuracy = 100 * correct / total
    return accuracy

...
# transforms applied to images from fer2013 dataset
transformations = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485], [0.229])])


optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr_init, eps=1e-07)
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=self.scheduler_mode,
                                                               patience=self.patience,
                                                               verbose=True)
...

I will post complete summaries on demand; here I suggest looking at the parameter totals for both:

[Keras]
Total params: 58,423
Trainable params: 56,951
Non-trainable params: 1,472

[Torch]
Total params: 56,951
Trainable params: 56,951
Non-trainable params: 0

As we can see, the original implementation has 1,472 non-trainable params. Going through the layers, I found that a Keras BatchNorm layer has twice as many params as a PyTorch BatchNorm layer. I don’t yet understand why, though.

So far I’ve tried reducing my dataset to only 10 samples to get loss=0, but the loss gets stuck around 1.16 and never goes down. Here is a plot:
[Plot: loss_func]
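
A minimal sketch of this kind of overfitting check, under stated assumptions: the tiny MLP, the random 10-sample batch, and the names `tiny_model`, `x`, and `y` are hypothetical stand-ins, not the MiniXception model or the FER2013 data. The stand-in model returns raw logits, which is what `nn.CrossEntropyLoss` expects. A network with enough capacity should drive the loss on 10 fixed samples close to zero; if it cannot, the problem is more likely in the model or loss wiring than in the data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in data: 10 random samples, 7 classes (as in FER2013).
x = torch.randn(10, 20)
y = torch.randint(0, 7, (10,))

# Hypothetical small model that outputs raw logits (no softmax layer).
tiny_model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 7))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(tiny_model.parameters(), lr=1e-2)

initial_loss = criterion(tiny_model(x), y).item()
for step in range(500):
    optimizer.zero_grad()
    loss = criterion(tiny_model(x), y)
    loss.backward()
    optimizer.step()
final_loss = criterion(tiny_model(x), y).item()

print(f"initial loss: {initial_loss:.4f}, final loss: {final_loss:.4f}")
```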

Training on the full dataset results in a loss of ~1.3 with little to no change after 300 epochs. Is there something I’m doing obviously wrong here?

Any help is appreciated, sorry for such a wall of text, not sure which information is crucial yet.

UPD:
Difference between # of parameters in keras and pytorch BatchNorm explained here: https://stackoverflow.com/questions/60079783/difference-between-keras-batchnormalization-and-pytorchs-batchnorm2d
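
The numbers can be checked with a short sketch. It assumes the usual layouts: Keras `BatchNormalization` reports the moving mean and moving variance as non-trainable parameters, while PyTorch's `nn.BatchNorm2d` keeps the equivalent `running_mean` and `running_var` as buffers, which `parameters()` does not return. The `bn_channels` list is read off the model definition above.

```python
import torch.nn as nn

bn = nn.BatchNorm2d(16)

# Learnable scale (weight) and shift (bias): 2 values per channel.
trainable = sum(p.numel() for p in bn.parameters())

# Running statistics live in buffers, so parameter counts skip them.
running = sum(
    b.numel()
    for name, b in bn.named_buffers()
    if name in ("running_mean", "running_var")
)

# Channel width of every BatchNorm2d in the model above:
# two in `beginning`, then three per core block (two separable + one residual).
bn_channels = [8, 8] + [c for c in (16, 32, 64, 128) for _ in range(3)]
untracked = 2 * sum(bn_channels)

print(trainable, running, untracked)
```

The last value is 1,472, which matches the non-trainable parameter count in the Keras summary, and 56,951 + 1,472 = 58,423.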

As per docs, when using ReduceLROnPlateau,

>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, 'min')
>>> for epoch in range(10):
>>>     train(...)
>>>     val_loss = validate(...)
>>>     # Note that step should be called after validate()
>>>     scheduler.step(val_loss)

But I don’t think you have that right, imo. Also, in my experience, it’s slightly tricky to match the performance of a Keras model; it sometimes needs param or arch changes as well.

Hm, I’m pretty sure my validate func is ‘pure’ and in docs they are talking about a loss function, criterion as it returns loss value. Am I wrong?

Yep, it’s true that criterion returns a loss value, but ReduceLROnPlateau is meant to monitor some kind of “val_metric”, not just any loss value. In your case, is the value you pass coming from validate?

Sample func signature,

# lr = lr * factor 
# mode='max': look for the maximum validation accuracy (a metric) to track
# patience: number of epochs - 1 where loss plateaus before decreasing LR
        # patience = 0, after 1 bad epoch, reduce LR
# factor = decaying factor
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)

I meant, torch.optim.lr_scheduler.ReduceLROnPlateau allows dynamic learning rate reducing based on some validation measurements… (but here you passed in train_loss, right?)
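
To make the `mode` and `patience` semantics concrete, here is a toy sketch that feeds a made-up sequence of validation accuracies to the scheduler. The single dummy parameter and the accuracy values are hypothetical; only the scheduler behaviour matters.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Dummy parameter so the optimizer has something to hold.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0)

lrs = []
for val_accuracy in [50.0, 49.0, 60.0]:  # made-up validation accuracies
    optimizer.step()              # stands in for one epoch of training
    scheduler.step(val_accuracy)  # the monitored metric goes here
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)
```

With `mode='max'` the LR drops after the second epoch, because 49.0 did not improve on 50.0. With `mode='min'` the same call expects a quantity that should decrease, such as a validation loss.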

PS I might be wrong with this thing, so please take it with a pinch of salt;

Extra notes (end-to-end MNIST example, modified from the examples dir in PyTorch):

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets

# Set seed
torch.manual_seed(0)

# Where to add a new import
from torch.optim.lr_scheduler import ReduceLROnPlateau

'''
STEP 1: LOADING DATASET
'''

train_dataset = dsets.MNIST(root='./data', 
                            train=True, 
                            transform=transforms.ToTensor(),
                            download=True)

test_dataset = dsets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

'''
STEP 2: MAKING DATASET ITERABLE
'''

batch_size = 100
n_iters = 6000
num_epochs = n_iters / (len(train_dataset) / batch_size)
num_epochs = int(num_epochs)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

'''
STEP 3: CREATE MODEL CLASS
'''
class FeedforwardNeuralNetModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FeedforwardNeuralNetModel, self).__init__()
        # Linear function
        self.fc1 = nn.Linear(input_dim, hidden_dim) 
        # Non-linearity
        self.relu = nn.ReLU()
        # Linear function (readout)
        self.fc2 = nn.Linear(hidden_dim, output_dim)  

    def forward(self, x):
        # Linear function
        out = self.fc1(x)
        # Non-linearity
        out = self.relu(out)
        # Linear function (readout)
        out = self.fc2(out)
        return out
'''
STEP 4: INSTANTIATE MODEL CLASS
'''
input_dim = 28*28
hidden_dim = 100
output_dim = 10

model = FeedforwardNeuralNetModel(input_dim, hidden_dim, output_dim)

'''
STEP 5: INSTANTIATE LOSS CLASS
'''
criterion = nn.CrossEntropyLoss()


'''
STEP 6: INSTANTIATE OPTIMIZER CLASS
'''
learning_rate = 0.1

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)

'''
STEP 7: INSTANTIATE STEP LEARNING SCHEDULER CLASS
'''
# lr = lr * factor 
# mode='max': look for the maximum validation accuracy to track
# patience: number of epochs - 1 where loss plateaus before decreasing LR
        # patience = 0, after 1 bad epoch, reduce LR
# factor = decaying factor
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=0, verbose=True)

'''
STEP 8: TRAIN THE MODEL
'''
iter = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Load images as Variable
        images = images.view(-1, 28*28).requires_grad_()

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = model(images)

        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()

        iter += 1

        if iter % 500 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                # Load images to a Torch Variable
                images = images.view(-1, 28*28)

                # Forward pass only to get logits/output
                outputs = model(images)

                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)

                # Total number of labels
                total += labels.size(0)

                # Total correct predictions
                # Without .item(), it is a uint8 tensor which will not work when you pass this number to the scheduler
                correct += (predicted == labels).sum().item()

            accuracy = 100 * correct / total

            # Print Loss
            # print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.data[0], accuracy))

    # Decay Learning Rate, pass validation accuracy for tracking at every epoch
    print('Epoch {} completed'.format(epoch))
    print('Loss: {}. Accuracy: {}'.format(loss.item(), accuracy))
    print('-'*20)
    scheduler.step(accuracy)

Thank you for your reply in the first place!

In my case I use the default mode min and monitor the loss rather than accuracy; the latter (which is returned from my validate function) is used only for the information I print each epoch:

INFO: --------------------
INFO: Epoch 293 completed
INFO: Loss: 1.3087555471597774. Accuracy: 45.57696772695612
INFO: --------------------
INFO: Epoch 294 completed
INFO: Loss: 1.3081310525009235. Accuracy: 45.57696772695612
INFO: --------------------
Epoch   295: reducing learning rate of group 0 to 1.0000e-08.
INFO: Epoch 295 completed
INFO: Loss: 1.3085813283295338. Accuracy: 45.623403761318784
INFO: --------------------
INFO: Epoch 296 completed
INFO: Loss: 1.3085084612997613. Accuracy: 44.78755514279081

And, as far as I can see, my lr changes when I expect it to: with patience set to 10, after 10 bad epochs it gets reduced by a factor of 0.1, down to the smallest value of 1e-08. It seems that I’m stuck somewhere in a local minimum of the loss surface. That said, I suppose my mistake is somewhere around the L2 regularisation that is applied in the Keras Conv2d layers.

As I understood, in Pytorch we need to set weight_decay parameter in our optimizer to achieve the same result, but I’m far from being confident here.

I see! Your arch looks okay to me as per paper; Try lowering the LR maybe?

As I understood, in Pytorch we need to set weight_decay parameter in our optimizer to achieve the same result, but I’m far from being confident here.

Yep!

Summaries for both models can be found under the following links, just in case.
https://gist.github.com/Nbooo/7d3f6f5c4dc8768f765b02408b66e6d0 (pytorch)
https://gist.github.com/Nbooo/2b47b384c20e2dfb5e1224083b2ca2f3 (keras)

The weight_decay might be too aggressive in PyTorch, as it’ll add all parameters to the regularization term.
Have a look at this post to exclude batch norm params etc. (or just add the conv parameters to the regularization).
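
One way to do this is with optimizer parameter groups. The sketch below is an assumption about how the split could look, not the code from the linked post: `split_weight_decay` is a hypothetical helper, the small `nn.Sequential` is a stand-in model, and the `weight_decay` value is a placeholder. Keras's `l2(λ)` regularizer adds λ·Σw² to the loss, so its gradient contribution is 2λw, whereas `weight_decay=wd` in `torch.optim.Adam` adds wd·w to the gradient; matching the two would therefore mean wd = 2λ.

```python
import torch
import torch.nn as nn

def split_weight_decay(model):
    """Return (decay, no_decay) parameter lists: conv kernels vs. everything else."""
    decay, no_decay = [], []
    for module in model.modules():
        for name, p in module.named_parameters(recurse=False):
            if isinstance(module, nn.Conv2d) and name == "weight":
                decay.append(p)
            else:
                no_decay.append(p)  # BatchNorm scale/shift and all biases
    return decay, no_decay

# Stand-in model, not the MiniXception defined above.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 7, kernel_size=3),
)
decay, no_decay = split_weight_decay(model)

optimizer = torch.optim.Adam(
    [
        {"params": decay, "weight_decay": 2e-2},  # placeholder value
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
print(len(decay), len(no_decay))
```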


Thanks for the link, I’ve tried to follow the suggestion to exclude batch norm params from regularization. Doesn’t seem to help, though.

As someone else mentioned, the metric for plateau LR scheduler should be from validation, not train, regardless of whether it is an accuracy, loss, or other metric. Defeats the point otherwise as train loss can easily continue to fall while validation loss starts increasing once you overfit.
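
A minimal sketch of that pattern, assuming a classifier that returns raw logits. `evaluate` is a hypothetical helper, and the random tensors and the linear model stand in for a real validation set and network. The helper also switches the model to `eval()` mode so that BatchNorm layers use their running statistics during validation.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, criterion):
    """Return (mean validation loss, accuracy in percent)."""
    model.eval()  # BatchNorm/Dropout switch to inference behaviour
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            logits = model(x)
            total_loss += criterion(logits, y).item() * y.size(0)
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    model.train()  # restore training mode for the next epoch
    return total_loss / total, 100.0 * correct / total

torch.manual_seed(0)
# Stand-in validation data: 64 random 1x48x48 images, 7 classes.
val_loader = DataLoader(
    TensorDataset(torch.randn(64, 1, 48, 48), torch.randint(0, 7, (64,))),
    batch_size=32,
)
model = nn.Sequential(nn.Flatten(), nn.Linear(48 * 48, 7))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=10)

for epoch in range(2):
    # ... one epoch of training would go here ...
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    scheduler.step(val_loss)  # monitor validation loss, not training loss
```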

Overall, matching training results across different frameworks isn’t always straightforward. What looks like the same network + optimizer, etc can be different enough to have an impact. Differences in image processing (even interpolation), normalization, optimizer params or implementation (momentums, epsilons, etc), layer params or impl (epsilons, momentums), can all have an impact. And also subtle mistakes like using the wrong combo of output softmax/sigmoid vs loss fn as one framework might name something differently or have different canonical examples than another.
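
One such mismatch can be checked directly. `nn.CrossEntropyLoss` applies log-softmax internally and expects raw logits, while the model in the question ends with `nn.Softmax(dim=1)`. If probabilities are fed to the loss, its inputs are confined to [0, 1], and the loss cannot drop below a fixed floor. The sketch below computes that floor for 7 classes; whether this is the actual cause of the plateau in the question is an assumption, not something the thread confirms.

```python
import math
import torch
import torch.nn.functional as F

num_classes = 7
target = torch.tensor([0])

# Best case after a softmax layer: probability 1 on the correct class.
perfect_probs = F.one_hot(target, num_classes).float()

# cross_entropy applies log-softmax again to whatever it receives.
loss_on_probs = F.cross_entropy(perfect_probs, target).item()

# Closed form of the same floor: log(1 + (C - 1) / e).
floor = math.log(1 + (num_classes - 1) / math.e)

# With raw logits the loss can approach zero.
confident_logits = 20.0 * perfect_probs
loss_on_logits = F.cross_entropy(confident_logits, target).item()

print(f"{loss_on_probs:.4f} {floor:.4f} {loss_on_logits:.6f}")
```

For 7 classes the floor is about 1.1654, close to the ~1.16 plateau reported for the 10-sample experiment. If this is the cause, the usual fix would be to drop the final `nn.Softmax` so the loss sees logits, and to apply softmax only at inference time.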


What a useful comment, thank you very much! Indeed I misused the scheduler. I’ll refine the learning procedure and see if it helps!

I understand that, I tried to pay attention to all the details.

Also try checking the [default] weight inits!
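
For reference, Keras `Conv2D` defaults to the `glorot_uniform` kernel initializer with zero biases, whereas the model above applies `kaiming_uniform_`. A sketch of switching to the Glorot/Xavier scheme in PyTorch follows; `init_like_keras` is a hypothetical helper name, and the fan-in/fan-out computation for grouped (depthwise) convolutions may not match Keras exactly.

```python
import math
import torch
import torch.nn as nn

def init_like_keras(module):
    """Glorot-uniform kernels and zero biases for conv layers."""
    if isinstance(module, nn.Conv2d):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
conv.apply(init_like_keras)  # Module.apply visits the module and its children

# Glorot bound: sqrt(6 / (fan_in + fan_out)), with fans scaled by kernel area.
fan_in, fan_out = 8 * 3 * 3, 16 * 3 * 3
bound = math.sqrt(6.0 / (fan_in + fan_out))
max_weight = conv.weight.abs().max().item()
bias_sum = conv.bias.abs().sum().item()
print(max_weight <= bound, bias_sum)
```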
