How to use CrossEntropyLoss on a transformer model with variable sequence length and constant batch size = 1?


I’m trying to train a model for predicting protein properties. The model takes a whole protein sequence as input (max_seq_len = 1000), creates an embedding vector for every sequence element, and then uses a linear layer to produce a vector with 2 elements to classify each sequence element into one of 2 classes. To support variable sequence lengths, and also because of GPU memory limitations, the model takes a single protein as a batch, i.e. batch_size is always 1. The output of the model has the shape [batch_size, sequence_length, num_classes] (e.g. [1, 376, 2] or [1, 123, 2]). The target tensor has the shape [batch_size, sequence_length], with each element of the sequence_length dimension being an integer indicating the class, i.e. 0 or 1.
How do I compute the loss correctly with CrossEntropyLoss in this case? Is CrossEntropyLoss the correct loss function for this case?

Thank you!

Yes, nn.CrossEntropyLoss sounds like the correct loss function for your multi-class sequence classification. To calculate the loss you would have to permute the model output, as this loss function expects the model output in the shape [batch_size, nb_classes, *] (where * would be the seq_len in your example), via output = output.permute(0, 2, 1).
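A minimal sketch of this setup, using random tensors with the shapes from your example ([1, 376, 2] output, [1, 376] target with class indices):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Example shapes from the question: batch_size=1, seq_len=376, num_classes=2
output = torch.randn(1, 376, 2)         # model output: [batch, seq_len, num_classes]
target = torch.randint(0, 2, (1, 376))  # class indices (0 or 1): [batch, seq_len]

# nn.CrossEntropyLoss expects [batch, num_classes, seq_len],
# so swap the last two dimensions before computing the loss
loss = criterion(output.permute(0, 2, 1), target)
```

Since the sequence length only appears in the trailing `*` dimension, the same code works unchanged for any protein length.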

Since you are dealing with a two-class classification, you could also check nn.BCEWithLogitsLoss, but you would need to make sure the model outputs a single logit for each element in the sequence.