Model cannot learn useful features (best approaches to debug?) - test loss is always much lower than train loss

Hi there,

I have created an autoencoder called CoarseNet whose job is to act as a low-pass filter for images and reconstruct only the main features of the input. Here are examples of the input images and the ground truth:

[image: input before transformations]

[image: ground truth after transformations]

[image: output after unnormalization]

And here are the loss values:

train loss:
epoch: 1, loss: 62.551551818847656
epoch: 10, loss: 36.72643280029297
epoch: 20, loss: 26.012357711791992

test loss:
loss: 19.805

This scale difference exists even when I do not train the model at all.
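A quick way to check whether the two numbers are even comparable is to evaluate the untrained model on both loaders and report the per-batch average under the same criterion (a minimal sketch, assuming the model, criterion, device, and loaders defined later in this post):

# Compare train vs. test loss of an untrained model on a per-batch-average
# basis, so both numbers are on the same scale.
def mean_loss(net, loader, criterion):
    net.eval()
    total, batches = 0.0, 0
    with torch.no_grad():
        for data in loader:
            X = data['X'].to(device)
            y = data['y_descreen'].to(device)
            total += criterion(net(X), y).item()
            batches += 1
    return total / max(batches, 1)

print('train (untrained):', mean_loss(coarsenet, train_loader, criterion))
print('test  (untrained):', mean_loss(coarsenet, test_loader, criterion))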
My code is available on GitHub.

I have pasted the code here too.

The model I am training - a modified version of UNet:

# %% Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F


# %% Submodules
class CL(nn.Module):
    def __init__(self, input_channel, output_channel):
        """
        It consists of a 4x4 convolution with stride=2 and padding=1, followed by
        a leaky rectified linear unit (Leaky ReLU).

        :param input_channel: input channel size
        :param output_channel: output channel size
        """

        assert (input_channel > 0 and output_channel > 0)

        super(CL, self).__init__()
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size=4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


# %%
class CBL(nn.Module):
    def __init__(self, input_channel, output_channel):
        """
        It consists of a 4x4 convolution with stride=2 and padding=1, and a batch normalization,
        followed by a leaky rectified linear unit (Leaky ReLU).

        :param input_channel: input channel size
        :param output_channel: output channel size
        """
        assert (input_channel > 0 and output_channel > 0)

        super(CBL, self).__init__()
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size=4, stride=2, padding=1),
                  nn.BatchNorm2d(num_features=output_channel), nn.LeakyReLU(0.2, inplace=True)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


# %%
class CE(nn.Module):
    def __init__(self, input_channel, output_channel, ks=4, s=2):
        """
        It consists of a 4x4 transposed convolution with stride=2 and padding=1 (kernel size
        and stride are configurable), followed by an exponential linear unit (ELU).

        :param input_channel: input channel size
        :param output_channel: output channel size
        :param ks: kernel size
        :param s: stride size
        """
        assert (input_channel > 0 and output_channel > 0)

        super(CE, self).__init__()
        layers = [nn.ConvTranspose2d(input_channel, output_channel, kernel_size=ks, stride=s, padding=1),
                  nn.ELU(alpha=1, inplace=True)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


# %%
class Contract(nn.Module):
    def __init__(self, input_channel, output_channel, module='cbl'):
        """
        It consists of a CL, CBL, or CE block; downsampling is done by the stride-2 convolution
        inside the block (there is no separate pooling layer).


        :param input_channel: input channel size
        :param output_channel: output channel size
        :param module: 'cl' for Convolution->LeakyReLU (CL class), 'cbl' for
                Convolution->BatchNorm->LeakyReLU (CBL class), or 'ce' for
                ConvTranspose->ELU (CE class, used as the first layer of the Expand/decoder path)
        """

        assert (input_channel > 0 and output_channel > 0)

        super(Contract, self).__init__()

        layers = []
        if module == 'cl':
            layers.append(CL(input_channel, output_channel))
        elif module == 'ce':
            layers.append(CE(input_channel, output_channel))
        else:
            layers.append(CBL(input_channel, output_channel))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


# %%
class Expand(nn.Module):
    def __init__(self, input_channel, output_channel, ks=4, s=2):
        """
        This path consists of a concatenation with the correspondingly padded feature map from the
        Contract phase, followed by an upsampling of the feature map via a 4x4 transposed
        convolution ("up-convolution") that halves the number of feature channels.


        :param input_channel: input channel size
        :param output_channel: output channel size
        """
        super(Expand, self).__init__()
        self.layers = CE(input_channel * 2, output_channel, ks, s)

    def forward(self, x1, x2):
        # Pad x2 to match x1's spatial size. Note that F.pad takes
        # (left, right, top, bottom): the width (dim 3) padding comes first,
        # then the height (dim 2) padding.
        delta_h = x1.size()[2] - x2.size()[2]
        delta_w = x1.size()[3] - x2.size()[3]
        x2 = F.pad(x2, pad=(delta_w // 2, delta_w - delta_w // 2,
                            delta_h // 2, delta_h - delta_h // 2),
                   mode='constant', value=0)
        x = torch.cat((x2, x1), dim=1)
        x = self.layers(x)
        return x
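For reference, a quick standalone check of the F.pad semantics (the tensor shape here is made up for illustration):

import torch
import torch.nn.functional as F

t = torch.zeros(1, 1, 10, 12)                      # (N, C, H, W)
p = F.pad(t, pad=(1, 2, 3, 4))                     # (left, right, top, bottom)
assert p.shape == (1, 1, 10 + 3 + 4, 12 + 1 + 2)   # H grows by top+bottom, W by left+right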


# %%
class C(nn.Module):
    def __init__(self, input_channel, output_channel):
        """
        At the final layer, a 3x3 convolution is used to map each 64-component feature vector to the desired
        number of classes.

        :param input_channel: input channel size
        :param output_channel: output channel size
        """
        super(C, self).__init__()
        self.layer = nn.Conv2d(input_channel, output_channel, kernel_size=3, padding=1, stride=1)

    def forward(self, x):
        return self.layer(x)


# %% Main class
class CoarseNet(nn.Module):
    def __init__(self, input_channels=3, output_channels=3):
        """
        Implementation of CoarseNet, a modified version of UNet.
        (https://arxiv.org/abs/1505.04597 - Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015))

        :param input_channels: number of input channels of input images to network.
        :param output_channels: number of output channels of output images of network.
        """

        super(CoarseNet, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # Encoder
        self.cl0 = Contract(input_channels, 64, module='cl')
        self.cbl0 = Contract(64, 128)
        self.cbl1 = Contract(128, 256)
        self.cbl2 = Contract(256, 512)
        self.cl1 = Contract(512, 512, module='cl')

        # Decoder
        self.ce0 = Contract(512, 512, module='ce')
        self.ce1 = Expand(512, 256)
        self.ce2 = Expand(256, 128)
        self.ce3 = Expand(128, 64)
        self.ce4 = Expand(64, 64)
        self.ce5 = CE(64, 64, ks=3, s=1)

        # final
        self.final = C(64, self.output_channels)

    def forward(self, x):
        out = self.cl0(x)  # 3>64
        out2 = self.cbl0(out)  # 64>128
        out3 = self.cbl1(out2)  # 128>256
        out4 = self.cbl2(out3)  # 256>512
        out5 = self.cl1(out4)  # 512>512
        in0 = self.ce0(out5)  # 512>512 (upsample)

        in1 = self.ce1(out4, in0)  # cat: 512+512>256
        in2 = self.ce2(out3, in1)  # cat: 256+256>128
        in3 = self.ce3(out2, in2)  # cat: 128+128>64
        in4 = self.ce4(out, in3)  # cat: 64+64>64
        f = self.ce5(in4)
        f = self.final(f)
        return f
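A minimal shape smoke test for the network (my own sketch; 256x256 matches the RandomResizedCrop size used during training below):

net = CoarseNet()
x = torch.randn(1, 3, 256, 256)
y = net(x)
assert y.shape == (1, 3, 256, 256), y.shape  # output must match input spatial size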

The function that generates the (halftone) input images from the ground truth images:

# %% libraries
import PIL.Image as Image
import matplotlib.pyplot as plt
import numpy as np
import random
import math


def error_diffusion(pixel, size=(1, 1)):
    """Floyd-Steinberg error diffusion, applied in place to a PIL pixel-access object."""
    for y in range(0, size[1] - 1):
        for x in range(1, size[0] - 1):  # skip the borders so the neighbour indices stay in range
            oldpixel = pixel[x, y]
            pixel[x, y] = 255 if oldpixel > 127 else 0

            # distribute the quantization error to the neighbours (7/16, 3/16, 5/16, 1/16)
            quant_error = oldpixel - pixel[x, y]
            pixel[x + 1, y] = pixel[x + 1, y] + int(7 / 16.0 * quant_error)
            pixel[x - 1, y + 1] = pixel[x - 1, y + 1] + int(3 / 16.0 * quant_error)
            pixel[x, y + 1] = pixel[x, y + 1] + int(5 / 16.0 * quant_error)
            pixel[x + 1, y + 1] = pixel[x + 1, y + 1] + int(1 / 16.0 * quant_error)

def generate_halftone(img):
    """Convert an RGB PIL image to a CMYK halftone and back to RGB."""
    img = img.convert('CMYK')
    channels = img.split()
    for chan in channels:
        error_diffusion(chan.load(), chan.size)
    return Image.merge("CMYK", channels).convert("RGB")
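Usage is straightforward (a sketch; 'sample.jpg' is a placeholder path):

img = Image.open('sample.jpg')  # hypothetical input image
ht = generate_halftone(img)
plt.subplot(1, 2, 1); plt.imshow(img); plt.title('ground truth')
plt.subplot(1, 2, 2); plt.imshow(ht); plt.title('halftone input')
plt.show()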

Custom Dataset

from __future__ import print_function, division
from PIL import Image
from skimage import feature, color
from torchvision.transforms import ToTensor, ToPILImage, Compose
import numpy as np
import random

import tarfile
import io
import os
import pandas as pd

from torch.utils.data import Dataset
import torch


class PlacesDataset(Dataset):
    def __init__(self, txt_path='filelist.txt', img_dir='data', transform=None, test=False):
        """
        Initialize data set as a list of IDs corresponding to each item of data set

        :param img_dir: path to the image files, either a folder or an uncompressed tar archive
        :param txt_path: a text file containing names of all of images line by line
        :param transform: apply some transforms like cropping, rotating, etc on input image
        :param test: is inference time or not
        :return: each sample is a 2-value dict containing the ground truth image (y_descreen)
                and the halftone input image (X) to feed into the network
        """

        df = pd.read_csv(txt_path, sep=' ', index_col=0)
        self.img_names = df.index.values
        self.txt_path = txt_path
        self.img_dir = img_dir
        self.transform = transform
        self.to_tensor = ToTensor()
        self.to_pil = ToPILImage()
        self.get_image_selector = 'tar' in img_dir
        self.tf = tarfile.open(self.img_dir) if self.get_image_selector else None
        self.transform_gt = transform if test else Compose(self.transform.transforms[:-1])  # omit the final RandomNoise for the ground truth

    def get_image_from_tar(self, name):
        """
        Gets an image by a name gathered from the file list csv file

        :param name: name of targeted image
        :return: a PIL image
        """
        image = self.tf.extractfile(name)
        image = image.read()
        image = Image.open(io.BytesIO(image))
        return image

    def get_image_from_folder(self, name):
        """
        Gets an image by a name gathered from the file list text file

        :param name: name of targeted image
        :return: a PIL image
        """

        image = Image.open(os.path.join(self.img_dir, name))
        return image

    def __len__(self):
        """
        Return the length of data set using list of IDs

        :return: number of samples in data set
        """
        return len(self.img_names)

    def __getitem__(self, index):
        """
        Generate one item of data set. Here we apply our preprocessing things like halftone styles and
        subtractive color process using CMYK color model, generating edge-maps, etc.

        :param index: index of item in IDs list

        :return: a sample of data as a dict
        """

        if index == (self.__len__() - 1) and self.get_image_selector:  # close tarfile opened in __init__
            self.tf.close()

        if self.get_image_selector:  # note: we prefer to extract then process!
            y_descreen = self.get_image_from_tar(self.img_names[index])
        else:
            y_descreen = self.get_image_from_folder(self.img_names[index])

        # generate halftone image
        X = generate_halftone(y_descreen)

        # Use the same seeds before both transform calls so the input and the
        # ground truth receive identical random crops/rotations/flips.
        seed = np.random.randint(2147483647)

        if self.transform is not None:
            random.seed(seed)
            torch.manual_seed(seed)
            X = self.transform(X)
            random.seed(seed)
            torch.manual_seed(seed)
            y_descreen = self.transform_gt(y_descreen)

        sample = {'X': X,
                  'y_descreen': y_descreen}

        return sample
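It is worth eyeballing a raw sample from this dataset before training anything (a minimal sketch, assuming the filelist.txt and data folder used in the training script below):

ds = PlacesDataset(txt_path='filelist.txt', img_dir='data', transform=ToTensor(), test=True)
s = ds[0]
print('X:', s['X'].shape, s['X'].min().item(), s['X'].max().item())
print('y:', s['y_descreen'].shape, s['y_descreen'].min().item(), s['y_descreen'].max().item())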


class RandomNoise(object):
    def __init__(self, p, mean=0, std=0.1):
        self.p = p
        self.mean = mean
        self.std = std

    def __call__(self, img):
        if random.random() <= self.p:
            noise = torch.empty(*img.size(), dtype=torch.float, requires_grad=False)
            return img+noise.normal_(self.mean, self.std)
        return img


class UnNormalizeNative(object):
    """
    Unnormalize an input tensor given the mean and std
    """

    def __init__(self, mean, std):
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """

        return Normalize((-mean / std).tolist(), (1.0 / std).tolist())(tensor)
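A one-line round-trip test catches most normalization mistakes (my sketch; Normalize comes from torchvision.transforms):

t = torch.rand(3, 8, 8)
m, s = [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]
restored = UnNormalizeNative(m, s)(Normalize(m, s)(t))
assert torch.allclose(t, restored, atol=1e-5)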

Train and Test methods

# %% import library
from CoarseNet import CoarseNet
# from pix2pix_unet import G  # alternative generator, unused here
from torchvision.transforms import Compose, ToPILImage, ToTensor, RandomResizedCrop, RandomRotation, \
    RandomHorizontalFlip, Normalize
from utils.preprocess import *  # PlacesDataset, RandomNoise, OnlineMeanStd, ...
import torch
from torch.utils.data import DataLoader
from utils.Loss import CoarseLoss

import torch.optim as optim
import torch.nn as nn
from torch.backends import cudnn

import argparse

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def init_weights(m):
    """
    Initialize weights of layers using Kaiming Normal (He et al.) as argument of "Apply" function of
    "nn.Module"

    :param m: Layer to initialize
    :return: None
    """

    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.kaiming_normal_(m.weight, mode='fan_out')
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):  # reference: https://github.com/pytorch/pytorch/issues/12259
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


# %% train model
def train_model(net, data_loader, optimizer, criterion, epochs=2):
    """
    Train model

    :param net: Parameters of defined neural network
    :param data_loader: A data loader object defined on train data set
    :param epochs: Number of epochs to train model
    :param optimizer: Optimizer to train network
    :param criterion: The loss function to minimize by optimizer
    :return: None
    """

    net.train()
    for epoch in range(epochs):

        running_loss = 0.0
        for i, data in enumerate(data_loader, 0):

            X = data['X']
            y_d = data['y_descreen']

            X = X.to(device)
            y_d = y_d.to(device)

            optimizer.zero_grad()

            outputs = net(X)
            loss = criterion(outputs, y_d)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Report the per-batch average rather than the running sum, so the
            # number is comparable across epochs and against the test loss.
            print('epoch: %d, batch: %d, avg loss: %.3f' % (epoch + 1, i + 1, running_loss / (i + 1)))
    print('Finished Training')


# %% test
def test_model(net, data_loader):
    """
    Return loss on test

    :param net: The trained NN network
    :param data_loader: Data loader containing test set
    :return: the outputs of the last batch; prints the average loss over the test set
             (uses the global criterion defined below)
    """

    net.eval()
    running_loss = 0.0
    outputs = None
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            X = data['X']
            y_d = data['y_descreen']
            X = X.to(device)
            y_d = y_d.to(device)
            outputs = net(X)
            loss = criterion(outputs, y_d)
            running_loss += loss.item()  # .item() so we accumulate a float, not a tensor
            print('avg loss: %.3f' % (running_loss / (i + 1)))
    return outputs


def show_image_batch(image_batch, name='out.png'):
    """
    Show a sample grid image which contains some sample of test set result

    :param image_batch: The output batch of test set
    :return: PIL image of all images of the input batch
    """

    to_pil = ToPILImage()
    fs = []
    for i in range(len(image_batch)):
        img = to_pil(image_batch[i].cpu())
        fs.append(img)
    x, y = fs[0].size
    ncol = int(np.ceil(np.sqrt(len(image_batch))))
    nrow = int(np.ceil(np.sqrt(len(image_batch))))
    cvs = Image.new('RGB', (x * ncol, y * nrow))
    for i in range(len(fs)):
        px, py = x * int(i / nrow), y * (i % nrow)
        cvs.paste((fs[i]), (px, py))
    cvs.save(name, format='png')
    cvs.show()


# parser = argparse.ArgumentParser()
# parser.add_argument("--txt", help='path to the text file', default='filelist.txt')
# parser.add_argument("--img", help='path to the images tar(bug!) archive (uncompressed) or folder', default='data')
# parser.add_argument("--txt_t", help='path to the text file of test set', default='filelist.txt')
# parser.add_argument("--img_t", help='path to the images tar archive (uncompressed) of testset ', default='data')
# parser.add_argument("--bs", help='int number as batch size', default=128, type=int)
# parser.add_argument("--es", help='int number as number of epochs', default=10, type=int)
# parser.add_argument("--nw", help='number of workers (1 to 8 recommended)', default=4, type=int)
# parser.add_argument("--lr", help='learning rate of optimizer (=0.0001)', default=0.0001, type=float)
# parser.add_argument("--cudnn", help='enable(1) cudnn.benchmark or not(0)', default=0, type=int)
# parser.add_argument("--pm", help='enable(1) pin_memory or not(0)', default=0, type=int)
# args = parser.parse_args()


class args:
  txt='filelist.txt'
  img='data'
  txt_t='filelist.txt'
  img_t='data'
  bs=9
  es=20
  nw=0
  lr=0.0001
  cudnn=0
  pm=0
  


if args.cudnn == 1:
    cudnn.benchmark = True
else:
    cudnn.benchmark = False

if args.pm == 1:
    pin_memory = True
else:
    pin_memory = False

# %% get dataset specific mean and std values
train_dataset = PlacesDataset(txt_path=args.txt,
                              img_dir=args.img,
                              transform=ToTensor(),
                              test=True)

mean, std = OnlineMeanStd()(train_dataset, batch_size=1, method='strong')
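OnlineMeanStd comes from utils.preprocess; if you do not have it, a minimal single-pass equivalent looks like this (my sketch; computing the statistics on X rather than y_descreen is an assumption):

# Per-channel mean/std over the dataset, accumulated with sums and sums of
# squares in one pass over the un-augmented images.
loader = DataLoader(train_dataset, batch_size=1, num_workers=0)
n_pix = 0
channel_sum = torch.zeros(3)
channel_sq = torch.zeros(3)
for batch in loader:
    x = batch['X']                       # assumption: statistics of the halftone inputs
    n_pix += x.size(2) * x.size(3)
    channel_sum += x.sum(dim=[0, 2, 3])
    channel_sq += (x ** 2).sum(dim=[0, 2, 3])
mean = channel_sum / n_pix
std = (channel_sq / n_pix - mean ** 2).sqrt()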


# %% define data sets and their loaders
custom_transforms = Compose([
    RandomResizedCrop(size=256, scale=(0.8, 1.2)),
    RandomRotation(degrees=(-30, 30)),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=mean, std=std),
    RandomNoise(p=0.0, mean=0, std=0.1)])

train_dataset = PlacesDataset(txt_path=args.txt,
                              img_dir=args.img,
                              transform=custom_transforms)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=args.bs,
                          shuffle=True,
                          num_workers=args.nw,
                          pin_memory=pin_memory)

test_dataset = PlacesDataset(txt_path=args.txt_t,
                             img_dir=args.img_t,
                             transform=ToTensor(),  # note: unlike the train set, no Normalize here,
                                                    # so train and test losses are on different scales
                             test=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=args.bs,
                         shuffle=False,
                         num_workers=args.nw,
                         pin_memory=pin_memory)

# %% initialize network, loss and optimizer
criterion = CoarseLoss(w1=50, w2=0).to(device)
coarsenet = CoarseNet().to(device)
optimizer = optim.Adam(coarsenet.parameters(), lr=args.lr)
coarsenet.apply(init_weights)  # initializing weights only helps color and won't remove patterns
train_model(coarsenet, train_loader, optimizer, criterion, epochs=args.es)
o = test_model(coarsenet, test_loader)

Loss function (in the demonstrated examples w2=0):

# %% libraries
import torch.nn as nn
import torch
from vgg import vgg16_bn


class CoarseLoss(nn.Module):
    def __init__(self, w1=50, w2=1, weight_vgg=None):
        """
        A weighted sum of pixel-wise L1 loss and sum of L2 loss of Gram matrices.

        :param w1: weight of the L1 loss (pixel-wise)
        :param w2: weight of the L2 loss (Gram matrix)
        :param weight_vgg: weights of the VGG-extracted features (should add up to 1.0)
        """
        super(CoarseLoss, self).__init__()
        if weight_vgg is None:
            weight_vgg = [0.5, 0.5, 0.5, 0.5, 0.5]
        self.w1 = w1
        self.w2 = w2
        self.l1 = nn.L1Loss(reduction='mean')
        self.l2 = nn.MSELoss(reduction='sum')
        # https://github.com/PatWie/tensorflow-recipes/blob/33962bb45e81f3619bfa6a8aeae5556cc7534caf/EnhanceNet/enet_pat.py#L169

        self.weight_vgg = weight_vgg
        self.vgg16_bn = vgg16_bn(pretrained=True).eval()
        for p in self.vgg16_bn.parameters():  # freeze the feature extractor
            p.requires_grad = False

    # reference: https://github.com/pytorch/tutorials/blob/master/advanced_source/neural_style_tutorial.py
    @staticmethod
    def gram_matrix(mat):
        """
        Return Gram matrix

        :param mat: A matrix  (a=batch size(=1), b=number of feature maps,
        (c,d)=dimensions of a f. map (N=c*d))
        :return: Normalized Gram matrix
        """
        a, b, c, d = mat.size()
        features = mat.view(a * b, c * d)
        gram = torch.mm(features, features.t())
        return gram.div(a * b * c * d)

    def forward(self, y_pred, y):
        # Note: called as criterion(outputs, y_d), so the first argument is the prediction.
        y_vgg = self.vgg16_bn(y)
        y_pred_vgg = self.vgg16_bn(y_pred)
        loss_vgg = [self.l2(self.gram_matrix(ly), self.gram_matrix(lp))
                    for ly, lp in zip(y_vgg, y_pred_vgg)]

        # np.dot on a list of 0-d tensors is fragile; a plain weighted sum is safer.
        loss = self.w1 * self.l1(y_pred, y) + \
               self.w2 * sum(w * l for w, l in zip(self.weight_vgg, loss_vgg))
        return loss
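Since the reply below recommends unit testing the loss, here is the kind of minimal test I would start with (a sketch; the tolerance is arbitrary, and the constructor still loads the VGG weights even with w2=0):

# Unit-test sketch for CoarseLoss: identical inputs must give (near-)zero loss,
# and a perturbed prediction must give a strictly larger loss.
def test_coarse_loss():
    torch.manual_seed(0)
    crit = CoarseLoss(w1=50, w2=0)  # w2=0 disables the Gram-matrix term
    y = torch.rand(2, 3, 64, 64)
    zero = crit(y.clone(), y)
    worse = crit(y + 0.5, y)
    assert zero.item() < 1e-6
    assert worse.item() > zero.item()

test_coarse_loss()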

Actually, I have been stuck here for months. I have tried many different implementations of UNet, learning rates, epoch counts, batch sizes, etc., but this is the best result I have got.

PS:

  1. I am trying to train the model for 20 epochs on only 9 images to make sure it overfits (a stripped-down version of this check is sketched after this list).
  2. Even when I feed the ground truth as input, the model still generates outputs with dot artifacts.
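The overfit check from point 1, reduced to a single fixed batch (assuming the model, criterion, optimizer, and loader defined above; if the loss does not collapse toward zero within a few hundred steps, the bug is in the model, loss, or data pipeline rather than in the amount of data):

# Overfit-one-batch test: take many optimizer steps on one fixed batch.
batch = next(iter(train_loader))
X = batch['X'].to(device)
y = batch['y_descreen'].to(device)
for step in range(300):
    optimizer.zero_grad()
    loss = criterion(coarsenet(X), y)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())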

I would sincerely appreciate any kind of help, whether it is a change to the structure or advice on how to track down issues.

Hi, as this has been unanswered for 2 months, how is it going? I cannot tell what the issue is here, but one general suggestion that has helped me tremendously in vision research is very simple: write unit tests, even for code that you think is plainly simple and works OK. Very, very often the problem is not in the training code but in the data preprocessing, validation, evaluation, etc. Definitely write unit tests for the loss function to see if it really does what you think it does :slight_smile:

All image and 3D transformations and normalisations are very hard to comprehend fully numerically, and it is very easy to make a small, trivial error that renders the results unusable.

Also, plotting and visualising the data after each processing step is essential for debugging; see the sketch below.
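For example, something as simple as this after each stage of the pipeline (a sketch using the dataset defined above; normalized tensors will deliberately look odd, which is itself informative):

# Visualize one sample as the network will actually see it.
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

sample = train_dataset[0]
for i, (title, img) in enumerate([('halftone X', sample['X']),
                                  ('ground truth', sample['y_descreen'])]):
    plt.subplot(1, 2, i + 1)
    plt.imshow(ToPILImage()(img.clamp(0, 1)))  # clamp: normalized values may leave [0, 1]
    plt.title(title)
plt.show()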

Yes, I have actually been stuck on this for about 5 months, and I abandoned it last month. I always try to write unit tests, but this is a deep learning model: how can I write one for it? I have googled DevOps tools for deep learning, but they are somewhat ambiguous. Over the last months I tried to learn Facebook's Visdom and a bunch of other tools to understand how the model behaves, and I can honestly say I have tried more than 1000 changes to make it work, but I do not know why it does not.

In the end, my main problem was that, as a student without a credit card, I could not use any cloud GPU, and I do not have a GPU in my own laptop, so testing different approaches was frustrating and time-consuming. That is why I abandoned the project.

Thanks for your answer.

Normally you do not need to unit-test the model itself. You can use assert statements to check that the outputs of different layers are as expected, but normally any error there would cause the next layer to fail. Just make sure that the architecture of the network is implemented as in the article.

I was referring to unit testing all the other code; in 95% of cases you will get such errors from:

  1. erroneous ground truth data
  2. erroneous loss function
  3. erroneous conversion (i.e., the model does the job well but you make an error in using the result: swapped an axis or a channel, mapped something wrongly, etc.)

Randomly trying out 1000 things is not a good idea; rigorously testing each and every line of code may look tedious, but it is actually much faster and leads to better results. A round-trip check like the one sketched below, for instance, catches the third class of errors.
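(A sketch of such a round-trip test; the 1/255 tolerance accounts for the 8-bit quantization of PIL images:)

# Tensor -> PIL -> tensor must preserve shape and (up to quantization) values;
# this catches swapped axes and channel-order mistakes.
import torch
from torchvision.transforms import ToTensor, ToPILImage

t = torch.rand(3, 32, 32)
rt = ToTensor()(ToPILImage()(t))
assert rt.shape == t.shape
assert torch.allclose(t, rt, atol=1 / 255)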


Thanks for your thorough explanation.
I will try to devise a plan to check the implementation step by step.

Yes, of course; blindly A/B testing changes to get a result in deep learning is not a good idea at all, even though many people treat deep learning as a black-box method, and I have always tried to avoid doing that.
One thing that came to my mind a few weeks ago was that I should learn Keras or fastai, a high-level framework where calling low-level functions is no longer needed. I may have used some of those low-level functions inappropriately in my custom class definitions, even though I wrote the custom functions myself.

Thank you again for your time.