UNet multiclass segmentation: output of my prediction is not what I expected

I am building a U-Net architecture for multiclass semantic segmentation. The goal is to recognize food in images (eggs, tomato, cheese, etc.). I am using Python with PyTorch. I have run my training code with my U-Net architecture, but the output prediction is not what I expected: the output shows only one class, with every pixel set to 1. Can someone help me with this? Thanks a lot.

The output:

[image: predicted mask, showing only a single class]

Here is my UNet architecture:

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    # Two 3x3 convolutions with ReLU (padding=1 keeps the spatial size)
    conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )
    return conv


class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()

        # max-pooling layers, one per encoder stage:

        self.max_pool_2x2_1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.max_pool_2x2_2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.max_pool_2x2_3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.max_pool_2x2_4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # self.max_pool_2x2_5 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.down_conv_1 = double_conv(3, 16)
        self.down_conv_2 = double_conv(16, 32)
        self.down_conv_3 = double_conv(32, 64)
        self.down_conv_4 = double_conv(64, 128)
        # self.down_conv_5 = double_conv(128, 256)
        # self.down_conv_6 = double_conv(256, 512)

        # self.up_conv_trans1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        # self.up_conv_1 = double_conv(512, 256)

        # self.up_conv_trans2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        # self.up_conv_2 = double_conv(256, 128)

        self.up_conv_trans3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(128, 64)

        self.up_conv_trans4 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(64, 32)

        self.up_conv_trans5 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)
        self.up_conv_5 = double_conv(32, 16)

        self.out = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1)

    def forward(self, image):

        out_conv1 = self.down_conv_1(image)
        out_pool1 = self.max_pool_2x2_1(out_conv1)
        out_conv2 = self.down_conv_2(out_pool1)
        out_pool2 = self.max_pool_2x2_2(out_conv2)
        out_conv3 = self.down_conv_3(out_pool2)
        out_pool3 = self.max_pool_2x2_3(out_conv3)
        out_conv4 = self.down_conv_4(out_pool3)
        # out_pool4 = self.max_pool_2x2_4(out_conv4)
        # out_conv5 = self.down_conv_5(out_pool4)

        # decoder part

        # out_up_conv = self.up_conv_trans2(out_conv5)
        # out_up_conv = self.up_conv_2(torch.cat([out_up_conv, out_conv4], 1))

        out_up_conv = self.up_conv_trans3(out_conv4)
        out_up_conv = self.up_conv_3(torch.cat([out_up_conv, out_conv3], 1))

        out_up_conv = self.up_conv_trans4(out_up_conv)
        out_up_conv = self.up_conv_4(torch.cat([out_up_conv, out_conv2], 1))

        out_up_conv = self.up_conv_trans5(out_up_conv)
        out_up_conv = self.up_conv_5(torch.cat([out_up_conv, out_conv1], 1))

        out_up_conv = self.out(out_up_conv)
        return out_up_conv
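
For reference, a quick shape check of this network (a sketch, assuming 256×256 RGB inputs as in my dataset):

net = UNet(num_classes=20)
x = torch.randn(1, 3, 256, 256)
print(net(x).shape)  # expected: torch.Size([1, 20, 256, 256])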

and here is my training code:

import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from UNET_19classes import *
from data_augmentation import transform


# Quick NaN check on the model output
def check_nan(model, input):
    output = model(input)
    return torch.isnan(output).any()

# Save a model checkpoint
def save_model(epoch, model, optimizer, loss):
    print("saving model checkpoint")
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss': loss,
    }, f'C:/Users/ne/Desktop/food-recognition/modele_sauvegarder/modele_{epoch}_{loss}.pth')

# Class definitions
classes = ['background', 'bun', 'laitue', 'tomates', 'cheddar', 'hamburger',
    'sauce', 'oignons', 'cornichon', 'fromage', 'poulet',
    'fish', 'bacon', 'oeuf', 'oignon cuit', 'pomme de terre / frite',
    'avocat', 'crevette', 'poulet pané/tenders', 'champignon']


# Class-name/index lookup tables and an index tensor
class_to_idx = {classes[i]: i for i in range(len(classes))}
idx_to_class = {i: classes[i] for i in range(len(classes))}
class_to_idx_tensor = torch.tensor(list(class_to_idx.values()))


# Instantiate the UNet model
model = UNet(num_classes=20)

# Load the training and test data
train_source = np.load("images_burger_train_set.npy")
train_source = train_source[..., ::-1]  # reverse the channel order (likely BGR -> RGB)
# train_source = train_source / 255.0
test_source = np.load("images_burger_test_set.npy")
train_target = np.load("mapped_masks.npy")
test_target = np.load("mapped_masks_test.npy")

# Pad the masks out to three channels: (N, H, W) -> (N, H, W, 3)
train_target = np.expand_dims(train_target, axis=3)
test_target = np.expand_dims(test_target, axis=3)
arr_zeros = np.zeros((156, 256, 256, 2))
arr_zeros2 = np.zeros((45, 256, 256, 2))
train_target = np.concatenate((train_target, arr_zeros), axis=-1)
test_target = np.concatenate((test_target, arr_zeros2), axis=-1)

# Loss function and optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# General parameters
num_epochs = 30
numpy_array_path = "."  # directory for saved metric arrays (placeholder; adjust as needed)

correct = 0
total = 0

# Move the model to the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)



# Main training loop
for epoch in range(num_epochs):
    # Training phase
    model.train()
    batch_size = 12
    train_acc = 0.0

    # Training bookkeeping
    number_slice_trainset = list(range(156))
    index_train = 0
    number_batch_train = 156 // batch_size
    random.shuffle(number_slice_trainset)
    num_correct_train = 0
    train_loss_array = []
    train_accuracy_array = np.array([0])
    train_loss = 0.0

    # Test bookkeeping
    test_epoch_loss = 0
    number_slice_testset = list(range(45))
    index_test = 0
    number_batch_test = 45 // batch_size
    random.shuffle(number_slice_testset)
    num_correct_test = 0
    test_loss_array = []
    test_accuracy_array = np.array([0])
    best_test_loss = np.finfo(np.float32).max
    val_loss = 0.0

    # Loop over training batches
    for data in range(number_batch_train):
        # Prepare the batch
        batch_train_source_numpy = train_source[number_slice_trainset[index_train:index_train + batch_size], :, :]
        batch_train_source_numpy = np.transpose(batch_train_source_numpy, (0, 3, 1, 2))
        batch_train_source_numpy = batch_train_source_numpy/255.0
        batch_train_target_numpy = train_target[number_slice_trainset[index_train:index_train + batch_size], :, :]
        batch_train_target_numpy = np.swapaxes(batch_train_target_numpy, 1, 3)
        batch_train_target_numpy = np.swapaxes(batch_train_target_numpy, 2, 3)

        batch_train_source_cuda = torch.tensor(batch_train_source_numpy).to(device=device, dtype=torch.float32)
        batch_train_target_cuda = torch.tensor(batch_train_target_numpy).to(device=device, dtype=torch.float32)
        batch_train_target_cuda = batch_train_target_cuda.cuda().float()
        batch_train_target_cuda = torch.argmax(batch_train_target_cuda, dim=1)
        result = transform(batch_train_source_cuda, batch_train_target_cuda)

        batch_train_logits_cuda = model(result[0])

        batch_prediction_cuda = torch.softmax(batch_train_logits_cuda, dim=1)
        preds = batch_prediction_cuda.cpu().detach().numpy()
        # Convert per-pixel probabilities to class indices
        preds_idx = np.argmax(preds, axis=1)

        # Per-pixel class indices as uint8
        names_for_each_pixel = preds_idx.astype(np.uint8)

        loss = loss_function(batch_train_logits_cuda, batch_train_target_cuda)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()
        index_train += batch_size

        preds = F.softmax(batch_train_logits_cuda, dim=1).argmax(dim=1)
        num_correct_train += (preds == batch_train_target_cuda).sum()

    train_loss /= number_batch_train
    print(f"epoch {epoch + 1}, train loss: {train_loss}")
    train_acc = num_correct_train / ((torch.numel(batch_train_logits_cuda)) * number_batch_train)


    # Evaluate the model on the test data
    model.eval()
    val_acc = 0.0

    # Loop over test batches
    with torch.no_grad():
        for data in range(number_batch_test):
            batch_test_source_numpy = test_source[number_slice_testset[index_test:index_test + batch_size], :, :]
            batch_test_source_numpy = np.swapaxes(batch_test_source_numpy, 1, 3)

            batch_test_target_numpy = test_target[number_slice_testset[index_test:index_test + batch_size], :, :]
            batch_test_target_numpy = np.swapaxes(batch_test_target_numpy, 1, 3)

            batch_test_source_cuda = torch.tensor(batch_test_source_numpy).to(device=device, dtype=torch.float32)
            batch_test_target_cuda = torch.tensor(batch_test_target_numpy).to(device=device, dtype=torch.float32)
            batch_test_target_cuda = torch.argmax(batch_test_target_cuda, dim=1)

            outputs = model(batch_test_source_cuda)

            loss = loss_function(outputs, batch_test_target_cuda)
            test_epoch_loss += loss
            index_test += batch_size

            val_loss += loss.item()
            preds = F.softmax(outputs, dim=1).argmax(dim=1)
            acc = (preds == batch_test_target_cuda).float().mean()
            val_acc += acc.item()

    test_epoch_loss /= number_batch_test
    print(f"epoch {epoch + 1}, test loss: {test_epoch_loss}")
    test_loss_array = np.append(test_loss_array, test_epoch_loss.item())
    np.save(os.path.join(numpy_array_path, 'test_loss_array'), test_loss_array)

Hi Nestlee!

You are making predictions for out_channels = num_classes = 20, but
up_conv_5 has already reduced your number of channels to 16. I wouldn’t
expect this to fully explain your issue, but I could see it hurting the performance
of your model in that out is not being passed enough per-pixel information
(only 16 channels) to make full predictions for all 20 classes. Maybe you could
reduce your numbers of channels more slowly and try to feed something like
48 channels into out.
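
A minimal sketch of what that could look like in __init__ (the 48-channel width is illustrative, not a tuned value):

# Illustrative only: keep more channels in the last decoder stage
self.up_conv_trans5 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)
self.up_conv_5 = double_conv(32, 48)  # 16 upsampled + 16 skip channels in, 48 out
self.out = nn.Conv2d(in_channels=48, out_channels=num_classes, kernel_size=1)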

You don’t show us what transform() actually is, but it looks like it is returning,
perhaps, a tuple, so I speculate that it is transforming your train_source and
train_target together in a consistent way. You pass the transformed “source”
data to your model to get your predicted “logits.”

But you then pass your untransformed target data to loss_function().
Depending on what transform() does, your untransformed targets might
no longer be consistent with your “source” and predictions, which could
definitely screw things up.
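
If transform() does return a consistently transformed (source, target) pair – I'm speculating about its return value – the fix would be along these lines:

# Speculative sketch, assuming transform() returns (image, mask) transformed together
transformed_source, transformed_target = transform(batch_train_source_cuda, batch_train_target_cuda)
batch_train_logits_cuda = model(transformed_source)
loss = loss_function(batch_train_logits_cuda, transformed_target)  # use the transformed target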

As an aside: softmax() doesn’t change the order of the values, so using
softmax() won’t affect the result of argmax(). Therefore, for the purpose
of getting integer class-label predictions from argmax(), you can safely
leave softmax() out. (Of course, leaving softmax() in won’t change the
result, so it’s okay either way.)
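
A quick self-contained check of that claim:

logits = torch.randn(2, 20, 8, 8)
# softmax is monotonic per pixel, so the argmax is unchanged
assert torch.equal(logits.argmax(dim=1), torch.softmax(logits, dim=1).argmax(dim=1))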

For debugging, I have the following suggestions:

You’ve displayed an “image” of your post-argmax() class labels. You should
also look at your (pre-argmax()) “logit” “image.” Are all of the per-pixel “logit”
values the same, or do they just map to the same class label?

Visualize some prediction “images” before you’ve trained your model at all
(using the model’s random initialization). The results should look random,
and certainly shouldn’t predict just a single class.

Also look at some target images. Do they make sense or do all (or almost all)
of the target pixels have the same value?

Can you overfit a small batch (of say four or eight samples) of source / target
pairs, leaving transform() out? If you train a lot on the same small batch,
you should be able to train your model to predict (by “memorization”) the
correct target values.
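
Such a test might look like the following (a sketch using your variable names; the loop count and batch slice are arbitrary):

# Illustrative overfitting test on one fixed small batch (no transform())
small_source = batch_train_source_cuda[:4]
small_target = batch_train_target_cuda[:4]
for step in range(500):
    optimizer.zero_grad()
    logits = model(small_source)
    loss = loss_function(logits, small_target)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        acc = (logits.argmax(dim=1) == small_target).float().mean().item()
        print(f"step {step}: loss {loss.item():.4f}, pixel accuracy {acc:.3f}")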

When you train (either for real or in your overfitting test) do your predictions
and model parameters change?

Best.

K. Frank

Thank you for your reply, @KFrank!

I found the problem, but I cannot see a solution with my current knowledge.
The problem comes from my mask array: whenever I expand its dimensions to (N, C, W, H), i.e. (45, 3, 256, 256), the pixel values get set to 1, so the mask appears with the correct contours but every pixel has a value of 1.

[image: the test-target mask]

Do you have any suggestions for expanding the dimensions of my mask array without affecting the pixel values?

I tried adding a 4th dimension without specifying that it should be 3. So instead of (45, 3, 256, 256) I have (45, 1, 256, 256), and the mask appears with its original values.

Hi Nestlee!

The short story is that you don’t want to expand your mask dimensions.

Based on my assumptions about your use case, you want
batch_train_target_cuda (the target you pass to your CrossEntropyLoss
loss_function) to be a LongTensor with shape [45, 256, 256]. (This will
be for a batch size of 45 and H = W = 256.) Note that this has no channels or
classes dimension. This tensor’s values should be your ground-truth integer
class labels that run from 0 to num_classes - 1 = 19.

I speculate that your train_target numpy array already starts out in this form.
If so, all you should have to do is convert it to a pytorch LongTensor. Do note
that because of how CrossEntropyLoss works, it is important that the target
you pass in is a LongTensor and not, for example, a FloatTensor that happens
to consist of integer values.
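
In code, that might look like this (a sketch assuming train_target is already (N, H, W) with integer labels in 0..19):

# Speculative sketch: no channel expansion, just convert to a LongTensor
batch_train_target_cuda = torch.from_numpy(
    train_target[number_slice_trainset[index_train:index_train + batch_size]]
).long().to(device)  # shape [batch_size, 256, 256]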

Best.

K. Frank

Hi @KFrank

Why do we have to use a LongTensor in this case? And what about the shape of the image tensor (batch_train_source_cuda): does it have to be (N, W, H) too?
In the end, how will we know which class each pixel belongs to if we don’t pass in a tensor with a defined number of channels?
If batch_train_target_cuda is (N, W, H) and batch_train_source_cuda is (N, C, W, H), I think it will not work.

Thank you,
Kind Regards,

Hi Nestlee!

Read the documentation for CrossEntropyLoss carefully (and see below).

No. For your U-Net implementation (which is typical) the image you input to
your model should have a channels dimension, in your case, three channels.
This is what’s expected by the first layer of your model, which is a Conv2d.

In your case, it looks like you are using RGB images – three channels – but
even if you were working with single-channel images (e.g., gray-scale), you
would still need a channels dimension, even though it would be a “singleton”
dimension (that is of size 1).

It will work as follows:

Your CrossEntropyLoss loss function takes an input tensor (the output
of your model – that is the predictions your model makes) of shape
[N, C, W, H] with, in your case, C = num_classes = 20. It also takes
a target tensor – your labels – with shape [N, W, H], specifically with
no channels / classes dimension. But how then does your target tensor
know how many classes you have? Because its values are integer class
labels in the range [0, num_classes - 1 = 19]. This will have type long,
that is, will be a LongTensor.

Your model will take as input an image of shape [N, channels = 3, W, H]
and output a prediction tensor of shape [N, num_classes = 20, W, H].
These tensors will have type float, that is, will be FloatTensors.

It is CrossEntropyLoss that matches up a FloatTensor input of
shape [N, num_classes, W, H] with a LongTensor target of shape
[N, W, H].*
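
A self-contained illustration of this shape matching (the dimensions are arbitrary):

import torch
import torch.nn as nn

N, num_classes, H, W = 2, 20, 256, 256
input = torch.randn(N, num_classes, H, W)          # float predictions (logits)
target = torch.randint(0, num_classes, (N, H, W))  # long integer class labels
loss = nn.CrossEntropyLoss()(input, target)        # shapes and types match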

As an aside, we usually call the last two dimensions of these tensors H,
for height, and W, for width, rather than W and H. But that doesn’t matter
for the actual computation – it’s just terminology.

*) CrossEntropyLoss works in two modes: It can take a target of “hard”
integer class labels of shape [N, W, H] (no num_classes dimension) of
type LongTensor or one of “soft” probabilistic class probabilities of shape
[N, num_classes, W, H] of type FloatTensor. It decides which mode
to use based on the type and shape of the target tensor you pass into it.
My assumption is that you are using a target that consists of “hard” integer
class labels.
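
Continuing the snippet above, the “soft” mode would look like:

# "Soft" probabilistic target: same shape and float type as the input
soft_target = torch.softmax(torch.randn(N, num_classes, H, W), dim=1)
loss_soft = nn.CrossEntropyLoss()(input, soft_target)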

Best.

K. Frank

Hi @KFrank!

Thank you for your explanation!

I modified my code so that the masks appear with the correct classes.

The problem now is that the test loss keeps increasing. The parameters I chose are the following:

  • loss function: cross-entropy loss
  • optimizer: Adam
  • lr: 1e-5
  • batch size: 8 (or 12)

Why does it increase? I checked my dataset (the masks match the images).

Hi Nestlee!

Let me assume that as you train the loss computed for your training set goes
down (and maybe some performance metric such as accuracy goes up), while
the loss computed for your test set goes up (and maybe some performance
metric goes down).

If so, you are most likely seeing overfitting, where your network gets trained to
“memorize” your training set and give good predictions for it, but it doesn’t really
“learn” to perform well on the data it wasn’t trained on.

Best.

K. Frank