Multi-Class Cross Entropy Loss function implementation in PyTorch

I’m trying to implement a multi-class cross entropy loss function in PyTorch for a 10-class semantic segmentation problem. The predictions and labels both have shape [4, 10, 256, 256], where 4 is the batch size, 10 is the number of channels, and 256x256 is the height and width of the images.

The following numpy implementation works, but I’m having difficulty getting a pure PyTorch implementation to work, or even a hybrid (less performant) one where I move the tensor out of PyTorch, calculate the loss in numpy, and return the result to PyTorch.

Here is the working numpy version:

import numpy as np

def multi_class_cross_entropy_loss(predictions, labels):
    """Calculate multi-class cross entropy loss for every pixel in an image, for every image in a batch.

    In the implementation,
    - the first sum is over all classes,
    - the second sum is over all rows of the image,
    - the third sum is over all columns of the image,
    - the last mean is over the batch of images.

    :param predictions: Output prediction of the neural network.
    :param labels: Correct labels.
    :return: Computed multi-class cross entropy loss.
    """

    loss = -np.mean(np.sum(np.sum(np.sum(labels * np.log(predictions), axis=1), axis=1), axis=1))

    return loss
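Written as a formula, the numpy function above computes the following, with N the batch size, C the number of classes, H and W the image height and width, y the labels and p the predicted probabilities (symbols chosen here for illustration). The sums are nested in the same order as in the code: classes, then rows, then columns, followed by the mean over the batch.

```latex
\mathcal{L} = -\frac{1}{N}\sum_{n=1}^{N}\sum_{c=1}^{C}\sum_{i=1}^{H}\sum_{j=1}^{W} y_{n,c,i,j}\,\log p_{n,c,i,j}
```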

This is an attempt at implementing the same function in PyTorch:

import torch

def multi_class_cross_entropy_loss_torch(predictions, labels):
    """Calculate multi-class cross entropy loss for every pixel in an image, for every image in a batch.

    In the implementation,
    - the first sum is over all classes,
    - the second sum is over all rows of the image,
    - the third sum is over all columns of the image,
    - the last mean is over the batch of images.

    :param predictions: Output prediction of the neural network.
    :param labels: Correct labels.
    :return: Computed multi-class cross entropy loss.
    """

    loss = -torch.mean(torch.sum(torch.sum(torch.sum(labels * torch.log(predictions), dim=1), dim=1), dim=1))
    return loss

for which I get the following error:

line 34, in multi_class_cross_entropy_loss_torch
    loss = -torch.mean(torch.sum(torch.sum(torch.sum(labels * torch.log(predictions), dim=1), dim=1), dim=1))
RuntimeError: Expected object of type torch.LongTensor but found type torch.FloatTensor for argument #2 'other'

Are labels of type torch.long?
If so, could you transform them to float with labels.float() and run it again?
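A minimal sketch of the suggested fix, using random toy tensors in place of the real U-Net output and masks: the labels are cast with `.float()` inside the function, so `torch.long` masks no longer clash with the float predictions. (Recent PyTorch versions promote mixed integer/float arithmetic automatically, so the original error is mostly seen on older releases like the one used in this thread.)

```python
import torch

def multi_class_cross_entropy_loss_torch(predictions, labels):
    # Cast integer (torch.long) masks to float so both operands share a dtype
    labels = labels.float()
    # Sum over classes, then rows, then columns; average over the batch
    loss = -torch.mean(torch.sum(torch.sum(torch.sum(
        labels * torch.log(predictions), dim=1), dim=1), dim=1))
    return loss

# Toy data: 2 images, 3 classes, 4x4 pixels
predictions = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)  # probabilities in (0, 1)
labels = torch.zeros(2, 3, 4, 4, dtype=torch.long)
labels[:, 0] = 1  # every pixel belongs to class 0
loss = multi_class_cross_entropy_loss_torch(predictions, labels)
print(loss.dtype, loss.item())
```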

Hi @ptrblck

Yes, the labels were of type torch.long. I converted them with float() in the output of the dataset, and I no longer get the type error.


Hi @ptrblck

How should I go about converting the following numpy expression to an equivalent one using torch.sum() or torch.view().sum()?

loss = -np.mean(np.sum(np.sum(np.sum(labels * np.log(predictions), axis=1), axis=1), axis=1))

There was a related discussion in the thread “Sum / Mul over multiple axes”, but I’m still trying to figure out the syntax.


You could try the following code:

batch_size = 4
-torch.mean(torch.sum(labels.view(batch_size, -1) * torch.log(preds.view(batch_size, -1)), dim=1))
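For reference, `torch.sum` also accepts a tuple of dimensions in current PyTorch releases, which expresses "sum over classes, rows and columns" directly without reshaping. A small sketch checking that this matches the `view`-based version above on random data (tensor names and sizes are only for illustration):

```python
import torch

torch.manual_seed(0)
batch_size = 4
preds = torch.softmax(torch.randn(batch_size, 10, 8, 8), dim=1)
labels = torch.zeros(batch_size, 10, 8, 8)
labels[:, 3] = 1.0  # toy one-hot target: every pixel is class 3

# Version from the reply above: flatten everything except the batch dimension
loss_view = -torch.mean(torch.sum(labels.view(batch_size, -1) * torch.log(preds.view(batch_size, -1)), dim=1))

# Same reduction expressed with a tuple of dims (classes, rows, columns)
loss_dims = -torch.mean(torch.sum(labels * torch.log(preds), dim=(1, 2, 3)))

print(torch.allclose(loss_view, loss_dims))
```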

That is compact; I’ll try it out.

What I came up with was a simple version, just to get it working, using only sum():

    n, c, h, w = predictions.size()
    nt, ct, ht, wt = labels.size()

    loss = (labels * torch.log(predictions)).sum(2).sum(2)  # for all pixels in an image
    loss = loss.sum(1)            # for all classes
    loss = -torch.mean(loss)  # for the entire batch

and another using view().sum(), which differs only in the first line:

    loss = (labels * torch.log(predictions)).view(n, c, -1).sum(2)  # for all pixels in an image
    loss = loss.sum(1)            # for all classes
    loss = -torch.mean(loss)  # for the entire batch

Hi @ptrblck

When I use the loss function in a training loop, the loss comes out as nan:

loss = loss_fn(output_batch, labels_batch)"loss shape: {}, loss data: {}".format(loss.shape,

train_unet:train:87: loss shape: torch.Size([]), loss data: nan

I unit tested the implementation with some test matrices for predictions and labels, and it gives the expected results.

Any idea what could cause this? It should be some large value at the start of training.

Are your predictions negative?
You could check it with

(output_batch < 0.0).any()

If so, torch.log will return nan for those entries.
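To see where the `nan` comes from, here is a tiny sketch with hand-picked values: `torch.log` of a negative number is `nan`, and `log(0)` is `-inf`, which also becomes `nan` once it is multiplied by a zero label. Either case is enough to turn the summed loss into `nan`.

```python
import torch

preds = torch.tensor([0.7, 0.0, -0.2])
labels = torch.tensor([1.0, 0.0, 0.0])

log_preds = torch.log(preds)   # [log(0.7), -inf, nan]
terms = labels * log_preds     # 0 * -inf and 0 * nan are both nan
loss = -terms.sum()

print(log_preds)
print(torch.isnan(loss).item())
```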

Hi @ptrblck

Yes, the predictions are negative: (output_batch < 0.0).any() returns 1.

What should I do? The input to the U-Net model is a set of RGB images, and the training mask is a 10-channel mask corresponding to the 10 classes that I want to segment in the image.

You could apply torch.sigmoid (F.sigmoid, now deprecated in its favor) to your predictions before passing them to the loss function.
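A different option from the sigmoid suggestion is to let `F.log_softmax` normalize the raw network outputs over the class dimension; it computes log-probabilities in a numerically stable way, so neither negative outputs nor exact zeros produce `nan`. A sketch of the thread's one-hot loss rewritten that way, assuming `logits` are the raw U-Net outputs of shape [N, C, H, W] and `labels` is a float one-hot mask of the same shape (the function name is hypothetical). Note that softmax, unlike sigmoid, assumes exactly one class per pixel.

```python
import torch
import torch.nn.functional as F

def one_hot_cross_entropy_from_logits(logits, labels):
    # Log-probabilities over the class dimension (dim=1), computed stably
    log_probs = F.log_softmax(logits, dim=1)
    # Sum over classes, rows and columns, then average over the batch
    return -(labels * log_probs).sum(dim=(1, 2, 3)).mean()

torch.manual_seed(0)
logits = torch.randn(4, 10, 16, 16) * 50  # large positive and negative raw outputs
labels = F.one_hot(torch.randint(0, 10, (4, 16, 16)), num_classes=10).permute(0, 3, 1, 2).float()
loss = one_hot_cross_entropy_from_logits(logits, labels)
print(torch.isfinite(loss).item())
```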

However, maybe I misunderstood your use case.
You said you are implementing a multi-class cross entropy. From the shape of your targets, it looks like you want a multi-label classification, i.e. each pixel location might belong to more than one class. Is this right, or does every pixel belong to exactly one class?
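If pixels really could carry several classes at once (the multi-label case asked about above), a common choice is `nn.BCEWithLogitsLoss`, which applies a sigmoid to each channel internally and accepts float targets with the same [N, C, H, W] shape as the output. This is a swapped-in alternative, not the loss used in this thread; the sketch uses random stand-ins for the U-Net output and the stacked 10-channel mask:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
output_batch = torch.randn(4, 10, 32, 32)                    # raw, unbounded outputs
labels_batch = torch.randint(0, 2, (4, 10, 32, 32)).float()  # independent 0/1 per channel

criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary cross entropy, averaged over elements
loss = criterion(output_batch, labels_batch)
print(loss.shape, loss.item())
```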

Hi @ptrblck

I have a set of 3-ch RGB images from the Kaggle DSTL Satellite Imagery Feature Detection dataset. It has labelled polygon data for 10 classes, ranging from buildings to trees, roads, waterways and vehicles.

I generated a binary mask for each class, where a pixel is 1 if it belongs to that class (e.g. a building) and 0 if it does not. I’ve stacked the individual masks into a single 10-ch mask.

Both the 3-ch image and 10-ch mask have been converted from numpy to torch float tensors, in the Dataset class.
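A rough sketch of the mask layout described here, assuming each class mask is a 2D numpy array of 0s and 1s; the `class_masks` list and the rectangle standing in for a building are hypothetical:

```python
import numpy as np
import torch

height, width, n_classes = 256, 256, 10
# Hypothetical per-class binary masks (1 = pixel belongs to that class)
class_masks = [np.zeros((height, width), dtype=np.uint8) for _ in range(n_classes)]
class_masks[0][100:120, 50:80] = 1  # e.g. a building in class 0

mask = np.stack(class_masks, axis=0)          # shape (10, 256, 256), channels first
mask_tensor = torch.from_numpy(mask).float()  # float tensor, as in the Dataset class
print(mask_tensor.shape)
```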

For the NN model, I am using a U-Net model from

and configured the U-Net model to output 10 channels:

class UNet(nn.Module):
    def __init__(self, conv_kernel=3,
                 pool_kernel=3, pool_stride=2,
                 repeat_blocks=2, n_filters=8,
                 batch_norm=True, dropout=0.1,
                 in_channels=3, out_channels=10,

I’m then passing the output_batch and labels_batch to the multi-class cross entropy loss function in my training loop, which is where I am getting the nan error.

This is the entire train loop:

def train(model, optimizer, loss_fn, dataloader, metrics, params):
    """Train the model on `num_steps` batches

    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) an object that fetches training data
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        num_steps: (int) number of batches to train on, each of size params.batch_size
    """

    # set model to training mode
    model.train()

    # summary for current training loop and a running average object for loss
    summ = []
    loss_avg = RunningAverage()

    # Use tqdm for progress bar
    with tqdm(total=len(dataloader)) as t:
        for i, (train_batch, labels_batch) in enumerate(dataloader):
            # move to GPU if available
            if params.cuda:
                train_batch, labels_batch = train_batch.cuda(non_blocking=True), labels_batch.cuda(non_blocking=True)
            # convert to torch Variables
            train_batch, labels_batch = Variable(train_batch), Variable(labels_batch)

            # compute model output and loss
            output_batch = model(train_batch)
            #logger.debug("train output_batch.shape = {}. labels_batch.shape = {}".format(output_batch.shape, labels_batch.shape))

            # check if predictions are negative
  "negative predictions: {}".format((output_batch < 0.0).any()))

            # compute loss
            loss = loss_fn(output_batch, labels_batch)
            logger.debug("loss: {}".format(

            # clear previous gradients, compute gradients of all variables wrt loss
            optimizer.zero_grad()
            loss.backward()

            # performs updates using calculated gradients
            optimizer.step()

            # Evaluate summaries only once in a while
            if i % params.save_summary_steps == 0:
                # extract data from torch Variable, move to cpu, convert to numpy arrays
                output_batch = output_batch.data.cpu().numpy()
                labels_batch = labels_batch.data.cpu().numpy()

                # compute all metrics on this batch
                summary_batch = {metric:metrics[metric](output_batch, labels_batch)
                                 for metric in metrics}
                summary_batch['loss'] = loss.item()
                summ.append(summary_batch)

            # update the average loss


    # compute mean of all metrics in summary
    metrics_mean = {metric:np.mean([x[metric] for x in summ]) for metric in summ[0]}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())"- Train metrics: " + metrics_string)

def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer, loss_fn, metrics, params, model_dir,
                       restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) an object that fetches training data
        val_dataloader: (DataLoader) an object that fetches validation data
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar)
    """

    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')"Restoring parameters from {}".format(restore_path))
        load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    for epoch in range(params.num_epochs):
        # Run one epoch"Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train(model, optimizer, loss_fn, train_dataloader, metrics=metrics, params=params)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics=metrics, params=params)

        # TODO: Fix TypeError: 'NoneType' object is not subscriptable
        val_acc = val_metrics['accuracy']
        is_best = val_acc>=best_val_acc

        # Save weights
        save_checkpoint({'epoch': epoch + 1,
                         'state_dict': model.state_dict(),
                         'optim_dict' : optimizer.state_dict()},

        # If best_eval, best_save_path
        if is_best:
  "- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
            save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir, "metrics_val_last_weights.json")
        save_dict_to_json(val_metrics, last_json_path)

def main():
    # print some log messages"DSTL Satellite Imagery Feature Detection - Train U-Net Model")

    # load parameters
    # load parameters from configuration file
    params = Params('experiment/unet_model/params_3ch.yaml', ParameterFileType.YAML, ctx=None)

    # parameters
    logger.debug("parameters: \n{}\n".format(pformat(params.dict)))

    # use GPU if available
    params.cuda = torch.cuda.is_available()

    # Set the random seed for reproducible experiments
    if params.cuda:

    # dataset parameters, which includes download, input, output and mask generation parameters.
    dataset_params = params.dataset
    logger.debug("dataset parameters: \n{}\n".format(pformat(dataset_params)))

    # dataset"loading datasets...")
    train_set = DSTLSIFDDataset(dataset_params=dataset_params,

    dev_set   = DSTLSIFDDataset(dataset_params=dataset_params,

    # dataloader
    logger.debug("train dataloader, batch size: {}, num workers: {}, cuda: {}".format(

    train_dl = DataLoader(dataset=train_set,

    logger.debug("dev dataloader, batch size: {}, num workers: {}, cuda: {}".format(

    valid_dl = DataLoader(dataset=dev_set,
                          pin_memory=params.cuda)"- done.")

    # define the model and optimizer
    #model = UNet()
    model = UNet().cuda() if params.cuda else UNet()"using adam optimized with lr = {}".format(float(params.learning_rate)))
    optimizer = optim.Adam(model.parameters(), lr=float(params.learning_rate))

    # loss function
    loss_fn = multi_class_cross_entropy_loss  # nn.MSELoss()  # nn.L1Loss() # nn.CrossEntropyLoss()

    # maintain all metrics required in this dictionary- these are used in the training and evaluation loops
    metrics = {
        'accuracy': accuracy,
        # could add more metrics such as accuracy for each token type
    }

    # train the model"Starting training for {} epoch(s)".format(params.num_epochs))

    data_dir = "data/"
    model_dir = "experiment/unet_model"


if __name__ == '__main__':
    main()

Ah ok, thanks for the info.
It looks like a standard segmentation task.

I would suggest using nn.CrossEntropyLoss for your use case.
Have a look at the following code snippet:

import torch
import torch.nn as nn

n_class = 10
preds = torch.randn(4, n_class, 24, 24)
labels = torch.empty(4, 24, 24, dtype=torch.long).random_(n_class)

criterion = nn.CrossEntropyLoss()
loss = criterion(preds, labels)

You don’t have to store the target as a “one-hot encoded” tensor; you can just pass the class indices to the criterion.
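To make the relationship to the hand-written loss concrete: with class-index targets, `nn.CrossEntropyLoss` applies `log_softmax` over dim 1 and picks the log-probability of the target class at every pixel. A small sketch reproducing the snippet above and checking it against a manual one-hot computation on the same random data:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_class = 10
preds = torch.randn(4, n_class, 24, 24)
labels = torch.empty(4, 24, 24, dtype=torch.long).random_(n_class)

loss = nn.CrossEntropyLoss()(preds, labels)

# Manual equivalent: one-hot targets, log_softmax over classes, mean over all pixels
one_hot = F.one_hot(labels, n_class).permute(0, 3, 1, 2).float()
manual = -(one_hot * F.log_softmax(preds, dim=1)).sum(dim=1).mean()

print(torch.allclose(loss, manual))
```

Note that the default `reduction='mean'` averages over all N·H·W pixels, while the hand-written loss earlier in the thread sums over pixels and averages only over the batch, so the two values differ by a factor of H·W.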

Maybe I misunderstood your question, but what is your reason for implementing the criterion manually?


Hi @ptrblck

I initially tried nn.CrossEntropyLoss() like this, with the image converted to float() and the mask to long():

    # loss function
    loss_fn = nn.CrossEntropyLoss() 

but it gave me this error:

  File "/tool/python/conda/env/gis36/lib/python3.6/site-packages/torch/nn/", line 1334, in nll_loss
    return torch._C._nn.nll_loss2d(input, target, weight, size_average, ignore_index, reduce)
RuntimeError: invalid argument 1: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4 at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/THCUNN/generic/

I wrote my own cross entropy loss function because I thought I had to, given the shape of my labels.

Is there a way to adapt and use the existing nn.CrossEntropyLoss() to work with the shape of my prediction and labels [batch, channels, height, width] ?

I think I’m missing something quite fundamental here. Is it okay to have my label masks with 10 channels, with each channel corresponding to a class (e.g. buildings)?

In your code snippet, labels is of shape [batch, height, width], whereas my labels are of shape [batch, channel, height, width]:
labels = torch.empty(4, 24, 24, dtype=torch.long).random_(n_class)

Yes, sure. You just have to get rid of the channel dimension in your targets, since you don’t need it.
The targets store the class index at each pixel position. Have a look at my sample code. My target doesn’t have the channel dim and it’s working.
Could you try that?
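Depending on the PyTorch version, there is also a second option: since release 1.10, `nn.CrossEntropyLoss` accepts floating-point targets containing class probabilities with the same [N, C, H, W] shape as the input, so a one-hot float mask can be passed as-is. The class-index form suggested above works on every version. A sketch checking that both give the same value for one-hot targets (random stand-in data):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
preds = torch.randn(4, 10, 24, 24)
indices = torch.randint(0, 10, (4, 24, 24))                   # [N, H, W] class indices
one_hot = F.one_hot(indices, 10).permute(0, 3, 1, 2).float()  # [N, C, H, W] float mask

criterion = nn.CrossEntropyLoss()
loss_indices = criterion(preds, indices)  # works on all versions
loss_probs = criterion(preds, one_hot)    # requires PyTorch >= 1.10
print(torch.allclose(loss_indices, loss_probs))
```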

Hi @ptrblck

I’m not sure I follow where to drop the channels. Should it happen when generating the mask, so that the mask is generated as [batch, height, width]?

If so, how should I encode and generate the mask?

I’m not sure how the targets are saved in the kaggle data.
Do you have target images with the “colors” encoding the classes?
Could you post a sample?

This is an example 3-ch RGB image:


and this is a binary mask for buildings (class 0). A pixel value is set to 1 if it belongs to a building, and to 0 if it doesn’t.


There are 10 channels in the mask, corresponding to the different classes. The mask shape is therefore [batch=4, channels=10, height=256, width=256].

This example shows the mask overlaid with the RGB image, to check for registration:


Some of the original papers on using Fully Convolutional Networks (FCN) for semantic segmentation used a pixel-wise prediction layer with 21 channels for the 20-class PASCAL VOC dataset.

So this is the part I don’t quite understand: should the output mask be a 10- or 11-channel image if there are 10 classes?

[Edit: The earlier paper that used FCN on the PASCAL VOC dataset had an additional channel for background pixels, which is why it had 21 channels. The DSTL dataset has 10 labelled classes. Therefore, masks should be generated for the background plus the 10 classes, which means the total number of channels for the DSTL dataset, when using a single U-Net model, should be 11.]
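A minimal sketch of building such 11-class targets directly as an index map, assuming the 10 binary masks are stacked as [N, 10, H, W] and that index 0 is reserved for background (the helper name `masks_to_index_map` is hypothetical). The U-Net would then need 11 output channels. Where two masks overlap, `argmax` returns the first class that is set, so overlapping pixels would need an explicit priority rule.

```python
import torch

def masks_to_index_map(masks):
    # masks: [N, 10, H, W] tensor of 0/1 values, one channel per labelled class
    masks = masks.float()
    has_class = masks.sum(dim=1) > 0     # [N, H, W]: is the pixel covered by any class?
    class_idx = masks.argmax(dim=1) + 1  # labelled classes become indices 1..10
    # Pixels with no class get index 0 (background) -> 11 classes in total
    return torch.where(has_class, class_idx, torch.zeros_like(class_idx))

masks = torch.zeros(1, 10, 2, 3)
masks[0, 0, 0, 0] = 1  # building at (0, 0)
masks[0, 4, 1, 2] = 1  # class 4 at (1, 2)
labels = masks_to_index_map(masks)
print(labels)  # tensor([[[1, 0, 0], [0, 0, 5]]])
```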


I assumed that the color coding was applied after the pixel-wise prediction, i.e. you take the detected pixels in each channel, color code them and generate a final colored mask.

Ok, so each channel corresponds to a class in your target images.
Now we have to get the class and just save the indices.
I created a small example where target is your target image and labels is the tensor we are using for the criterion:

import torch
import torch.nn.functional as F

target = torch.randn(1, 10, 24, 24)
target = (F.softmax(target, dim=1) > 0.5).long()
# target now has a 1 in channel c if the pixel location belongs
# to that class (like your target images)

labels = torch.argmax(target, dim=1)
# labels now holds the index of the corresponding class
# (pixels with no channel set fall back to index 0)
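Applied to a full batch, the same `argmax` conversion removes only the channel dimension. A sketch with random stand-ins for the stacked 10-channel mask (`labels_batch`) and the raw U-Net output (`output_batch`), checking that the recovered indices match the ones the mask was built from:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
true_idx = torch.randint(0, 10, (4, 256, 256))
labels_batch = F.one_hot(true_idx, 10).permute(0, 3, 1, 2).float()  # [4, 10, 256, 256]
output_batch = torch.randn(4, 10, 256, 256)                          # raw model output

targets = torch.argmax(labels_batch, dim=1)  # [4, 256, 256], dtype torch.long
loss = nn.CrossEntropyLoss()(output_batch, targets)
print(targets.shape, targets.dtype)
```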

Hi @ptrblck

I did some training yesterday and the NN seems to be learning: the loss keeps going down steadily and appears to be converging.

Epoch 1/100
100% 5/5 [00:05<00:00,  1.80s/it, loss=2.325]
100% 5/5 [00:05<00:00,  1.90s/it, loss=2.281]
100% 5/5 [00:05<00:00,  1.78s/it, loss=2.242]

100% 5/5 [00:04<00:00,  1.41s/it, loss=0.075]
100% 5/5 [00:05<00:00,  1.95s/it, loss=0.072]

Epoch 100/100
100% 5/5 [00:04<00:00,  1.50s/it, loss=0.070]

I’ll now need to look at the output masks and see if it is working correctly.

I must admit I don’t quite follow what’s happening here. How does the class-to-index mapping work? I understand the part about running the 10 classes through the softmax as an activation, but what happens to a batch of, say, [4, 10, 24, 24]? Do we just drop the batch dimension when doing F.softmax(target, dim=1) > 0.5?
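`argmax(dim=1)` only collapses the channel dimension: for each image in the batch and each pixel, it returns the index of the channel that holds the 1, and the batch dimension is left untouched. A tiny example with hand-picked values (sizes chosen only to keep it readable):

```python
import torch

# Two images (N=2), three classes (C=3), a 1x2 pixel grid (H=1, W=2)
masks = torch.zeros(2, 3, 1, 2)
masks[0, 2, 0, 0] = 1  # image 0, left pixel  -> class 2
masks[0, 0, 0, 1] = 1  # image 0, right pixel -> class 0
masks[1, 1, 0, 0] = 1  # image 1, left pixel  -> class 1
masks[1, 1, 0, 1] = 1  # image 1, right pixel -> class 1

labels = torch.argmax(masks, dim=1)  # reduce only the class dimension
print(labels.shape)     # torch.Size([2, 1, 2])
print(labels.tolist())  # [[[2, 0]], [[1, 1]]]
```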

Don’t we lose the per-channel correspondence of the loss when doing this? The loss ends up as a single scalar value, so is the framework able to keep track of the loss per channel during backpropagation?
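On the second question: the scalar loss is a function of every element of the [N, C, H, W] output, so `backward()` still produces a gradient for each channel of each pixel. For `nn.CrossEntropyLoss` with the default mean reduction, that gradient is (softmax(output) - one_hot(target)) / (N·H·W). A sketch verifying this on small random tensors:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 4, 10, 8, 8
output = torch.randn(N, C, H, W, requires_grad=True)
target = torch.randint(0, C, (N, H, W))

loss = nn.CrossEntropyLoss()(output, target)  # a single scalar
loss.backward()

# Analytical gradient for every channel of every pixel
expected = (F.softmax(output.detach(), dim=1)
            - F.one_hot(target, C).permute(0, 3, 1, 2).float()) / (N * H * W)
print(output.grad.shape, torch.allclose(output.grad, expected, atol=1e-6))
```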