Incredibly High CrossEntropyLoss in Sequence-to-Sequence Generation

I’m trying to do SMILES chemical representation prediction on a large dataset (around 5M samples) in order to teach the model a downstream task. The part of the model responsible for generating the data is a decoder with an embedding layer that roughly looks like this:

self.decoder_embedding = nn.Embedding(len(tokenizer), hidden_size)
decoder_layer = nn.TransformerDecoderLayer(
    d_model=hidden_size,
    nhead=heads,
    dim_feedforward=hidden_size,
    dropout=dropout,
    batch_first=True,
    norm_first=True
)
self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)
self.smiles_generation_head = nn.Linear(hidden_size, 125)

The inputs to the model include smiles_tokens of shape [batch_size, 125], which have been randomly masked, and a padding attention mask. The output is simply pretext_predictions of shape [batch_size, 125].

def mask_tokens(inputs, tokenizer, mask_prob=0.15):
    masked_inputs = inputs.clone()
    mask_token_id = tokenizer.mask_token_id

    # Mask a random selection of tokens
    random_mask = torch.rand(masked_inputs.shape) < mask_prob
    masked_inputs[random_mask] = mask_token_id

    return masked_inputs

masked_smiles_tokens = mask_tokens(non_share_smiles_tokens, tokenizer)

smiles_tokens = model(masked_smiles_tokens, *features)

When this is passed into the loss function, it starts at an incredibly high loss value (roughly 10K and above). When the batch size is increased, the loss value also doubles and sometimes triples to around 30K or above.

smiles_tokens = nn.ConstantPad1d(
    (0, 125 - smiles_tokens.shape[1]),
    0
)(smiles_tokens).float()
padding_mask = smiles_tokens != 0
loss = loss_fn(
    pretext_predictions[padding_mask],
    smiles_tokens[padding_mask],
)

This is the loss function definition:

def loss_fn(inputs, targets):
    ce_criterion = nn.CrossEntropyLoss(reduction='mean')
    ce_loss = ce_criterion(inputs, targets)
    return ce_loss

What is causing this high loss value? I tried normalizing my input features (apart from the SMILES token indices), but that doesn’t seem to solve the issue. I noticed that when I create random tensors like this:

import torch
import torch.nn as nn
import random

# Example tensors (replace with your actual data)
predictions = torch.randn(16, 50).softmax(dim=1)
targets = torch.randint(0, 100, (16, 50)).float()
mask = torch.randint(0, 2, (16, 50))

# Create the loss function
loss_fn = nn.CrossEntropyLoss() 

# Calculate the loss
loss = loss_fn(predictions[mask], targets[mask])

print(loss)

It also results in huge loss values:

tensor(9981.9561)

My ultimate aim is a BCE task, and the loss from BCE is very small (0 to 1) compared to the CE loss. This makes it difficult to assess whether the model is making any progress. What should I do? Do I simply find another loss function? NLLLoss doesn’t seem to be doing well either, but CE loss uses it under the hood, so I’m guessing the problem stems from that?

Note: pretext_predictions are logits, and smiles_tokens are indices for the tokenizer vocabulary of size 37. The learning rate is 1e-3 using Adam and model size is only 5M parameters.

So I just noticed something unusual. I tried implementing CrossEntropyLoss manually myself and got different results to the native PyTorch nn.CrossEntropyLoss.

Example to reproduce:

import torch
import torch.nn as nn
import random

predictions = torch.tensor([
    2.6806,  1.6770, -0.5246,  1.0230,  1.9406,  1.5417, -2.1595,  3.4809,
    -0.3840, -1.7314, -2.3532,  1.3828, -0.6205,  0.5228, -2.8180, -4.2407,
    0.1465, -0.9347,  0.3158,  1.5887, -1.2187,  0.2900,  1.4955,  1.9007,
    1.9633,  2.2591, -0.0227, -0.6740,  0.4968,  1.4456, -1.4887,  1.4188,
    0.7862,  1.3885,  4.4571, -0.3494,  2.4859, -1.0327,  0.8690,  0.5135,
    0.3481, -0.2618,  0.6049, -0.6884, -1.1072,  0.1194, -0.3943, -2.6933,
    -0.5439,  0.0940, -0.0691,  2.3558, -1.2147, -3.9804, -3.6275, -2.4481,
    -2.1119,  0.2207,  3.5445, -1.2965, -1.4442,  0.8695, -1.8092, -0.2379,
    0.3036,  2.5865,  2.4879, -0.1054, -2.6106,  2.5319, -0.7971,  3.0985,
    3.2366,  0.1034,  1.5961,  0.6682,  2.4263, -2.0714,  3.0258, -0.9155,
    -1.9036, -2.1316,  1.5482,  0.4479,  0.9619, -2.8073, -4.5360,  0.5597,
    -1.1316,  1.0735,  1.7407, -1.2951, -0.5784,  0.6017,  2.1548,  2.2024,
    2.6517,  0.3819, -0.7844,  0.4975,  0.8481, -0.8627,  1.0941,  0.4227,
    1.7440,  4.9244, -0.6822,  0.5583, -0.8062,  1.1259,  0.9110,  0.4485,
    -1.0525,  0.7695, -0.7216, -1.7309, -0.1571, -0.6827, -3.0033,  0.0365,
    -0.0286, -0.5585,  3.1764, -1.0799, -5.4084, -3.2902, -2.4649, -1.8878,
    0.9609,  3.8043, -0.6734, -1.5818,  0.4709, -1.7792,  0.3450,  0.8633,
    2.4463,  3.1707
])
targets = torch.tensor([
    12., 19., 22., 16., 17., 23., 18., 15., 20., 15., 15., 17., 31., 15.,
    21., 15., 15., 15., 15., 17., 19., 16., 26., 16., 16., 16., 16., 26.,
    18., 15., 21., 18., 15., 25., 15., 20., 23., 20., 16., 16., 16., 16.,
    17., 23., 16., 17., 22., 19., 18., 16., 21., 16., 16., 17., 16., 26.,
    16., 16., 19., 16., 16., 26., 18., 22., 23., 19., 21., 18., 16., 20.,
    13., 12., 16., 16., 19., 16., 17., 22., 19., 18., 15., 20., 25., 44.,
    25., 15., 20., 23., 15., 20., 25., 15., 17., 23., 15., 21., 15., 25.,
    15., 15., 17., 27., 18., 15., 21., 18., 25., 15., 17., 23., 16., 17.,
    16., 15., 21., 15., 15., 15., 17., 16., 18., 15., 17., 16., 18., 15.,
    21., 18., 16., 17., 22., 19., 18., 23., 18., 25., 20., 13.
])


criterion = nn.CrossEntropyLoss()
loss = criterion(predictions, targets.float())
print("Native CrossEntropyLoss:", loss)

criterion = nn.NLLLoss()
predictions = torch.log(torch.softmax(predictions, dim=0))
loss = criterion(predictions, targets.long())
print("Manual CrossEntropyLoss:", loss)
Output:

Native CrossEntropyLoss: tensor(16612.7012)
Manual CrossEntropyLoss: tensor(7.2559)

Is my implementation wrong?

There are a few different issues in your code:

  • nn.CrossEntropyLoss expects raw logits, so remove the softmax applied on the predictions.
  • nn.CrossEntropyLoss expects a target tensor containing class indices in the range [0, nb_classes-1] as a LongTensor in the shape [batch_size, *], or a “soft” target containing floating point values in the range [0, 1] in the same shape as your model output, i.e. [batch_size, nb_classes] for a multi-class classification. Currently you are sampling from the range [0, 100] (which is not a valid class index range based on the shape of predictions) and are using a FloatTensor (so you are effectively passing “probabilities” > 1).
  • The mask is changing the shape of the model output and target to torch.Size([16, 50, 50]), and I don’t think you want that: because the mask is an integer tensor rather than a boolean one, indexing with it gathers rows instead of selecting elements.
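Putting those three fixes together, your random-tensor snippet would look something like this (a minimal sketch; the number of classes and the mask probability are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Raw logits (no softmax), with the class count matching the sampling range
predictions = torch.randn(16, 100)
# Class indices in [0, nb_classes - 1] as a LongTensor
targets = torch.randint(0, 100, (16,))
# A boolean mask selects rows instead of acting as an index tensor
mask = torch.rand(16) < 0.5

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(predictions[mask], targets[mask])
print(loss)  # on the order of log(100) ≈ 4.6, not thousands
```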

Is my custom implementation of CrossEntropyLoss also wrong?

Hi, to add further information:

You can think of CrossEntropyLoss this way:

Logits → Probabilities (Softmax) → Get the probability of the correct class → Negative log-likelihood → Take the mean → Loss

You can refer to this implementation:

import torch
import torch.nn as nn

torch.manual_seed(0)
BATCH_SIZE = 16
NO_CLASSES = 50
logits = torch.randn(BATCH_SIZE, NO_CLASSES)
true_prediction = torch.randint(0, NO_CLASSES, (1, BATCH_SIZE)).squeeze(0)
loss_fn = nn.CrossEntropyLoss() 
loss = loss_fn(logits, true_prediction)

Here is the manual version of it:

torch.manual_seed(0)

BATCH_SIZE = 16
NO_CLASSES = 50
logits = torch.randn(BATCH_SIZE, NO_CLASSES)
true_prediction = torch.randint(0, NO_CLASSES, (1, BATCH_SIZE)).squeeze(0)
probabilities = logits.softmax(dim=-1)
nll = -probabilities[torch.arange(BATCH_SIZE), true_prediction].log().mean()
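Since both snippets use the same seed, you can also confirm that the manual computation matches the built-in one:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
BATCH_SIZE = 16
NO_CLASSES = 50
logits = torch.randn(BATCH_SIZE, NO_CLASSES)
true_prediction = torch.randint(0, NO_CLASSES, (1, BATCH_SIZE)).squeeze(0)

# Built-in loss on raw logits
native = nn.CrossEntropyLoss()(logits, true_prediction)

# Manual: softmax -> pick correct-class probability -> negative log -> mean
probabilities = logits.softmax(dim=-1)
manual = -probabilities[torch.arange(BATCH_SIZE), true_prediction].log().mean()

print(torch.allclose(native, manual))  # True
```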

Hope the code above helps you understand it. If you find it hard to follow, you may refer to this tutorial: Andrej Karpathy - Building makemore Part 2: MLP (timestamp 32:52)

Given 1D tensors like this, what is a good loss function? targets is a set of indices.

I wouldn’t focus on the shapes being used, but on the use case, and then pick the right loss function. E.g. if you are working on a multi-class classification, use nn.CrossEntropyLoss and fix the tensor shapes and dtypes as explained in my previous post.

Thank you so much. I figured out that the output needs a shape of [batch_size, classes, seq_len] and the target needs [batch_size, seq_len].

Here is a reproducible example.

import torch
import torch.nn as nn

torch.manual_seed(27)

predictions = torch.randn(2, 5, 5)
predictions = nn.ConstantPad1d(
    (0, 10 - predictions.shape[2]),
    0
)(predictions).float()

targets = torch.randint(0, 5, (2, 5)).float()
targets = nn.ConstantPad1d(
    (0, 10 - targets.shape[1]),
    0
)(targets).long()

# Create the loss function; padded target positions (index 0) are ignored
loss_fn = nn.CrossEntropyLoss(ignore_index=0)

# Calculate the loss
loss = loss_fn(predictions, targets)

print(loss)

The essence is to make sure the output has a “number of classes” dimension.
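For reference, applied to the SMILES setup from the original post (assuming the decoder produces [batch_size, seq_len, vocab_size] logits, with seq_len = 125 and vocab_size = 37 as stated there), the class dimension can be moved into second place with permute:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 4, 125, 37

# Typical decoder output: [batch_size, seq_len, vocab_size]
logits = torch.randn(batch_size, seq_len, vocab_size)
# Target token indices: [batch_size, seq_len]
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

loss_fn = nn.CrossEntropyLoss()
# CrossEntropyLoss wants [batch_size, classes, seq_len]
loss = loss_fn(logits.permute(0, 2, 1), targets)
print(loss)  # a sane value of a few nats, not thousands
```

Equivalently, you can flatten the sequence dimension into the batch: loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1)) gives the same result.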