I have a video-analysis task where I need to split the last fully connected layer of a (2+1)D ResNet into multiple chunks and compute a classification loss on each chunk using CrossEntropyLoss. For more context: I’ve reshaped my model’s last (FC) layer to produce an output of shape (64, 4, 21), where 64 is the batch size, 4 is the number of chunks, and 21 is the number of classes. So basically my model predicts one class per chunk, i.e. 4 class predictions for each video segment. I am computing the loss in the following manner:

def calculate_loss(output, target, criterion):
    loss = 0
    for i in range(output.size(1)):
        loss += criterion(output[:, i, :], target[:, i])
    return loss  # / output.size(1)

Here I iterate over the chunks and compute the classification loss for each one. My model trains well for the first few batches (150/2500), but then both loss and accuracy stagnate. When I check the output for each chunk, the predictions are homogeneous (almost always the same class), even though the target has different classes in each chunk.

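As an aside, you can use the “K-dimensional case” feature of CrossEntropyLoss (if you’re using a recent enough version of
PyTorch) to eliminate your for loop. CrossEntropyLoss expects the class
dimension at position 1, so moving it there (for example with
output.permute(0, 2, 1)) and passing the result to criterion together
with target, then multiplying by the number of chunks to undo the
default mean over chunks, should yield the same result as your
calculate_loss function.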
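
As a rough sketch of the loop-free version: the snippet below compares the loop from the question with a single K-dimensional call. The helper name calculate_loss_kdim and the random stand-in tensors are illustrative assumptions, and the comparison assumes the default unweighted criterion with reduction="mean" and no ignore_index.

```python
import torch
import torch.nn as nn


def calculate_loss(output, target, criterion):
    # Loop version from the question: sum the per-chunk losses.
    loss = 0
    for i in range(output.size(1)):
        loss += criterion(output[:, i, :], target[:, i])
    return loss


def calculate_loss_kdim(output, target, criterion):
    # Hypothetical helper name. CrossEntropyLoss expects the class
    # dimension at position 1: (batch, classes, chunks) for the input
    # and (batch, chunks) for the target.
    logits = output.permute(0, 2, 1)
    # The default "mean" reduction averages over batch and chunks, so
    # multiply by the number of chunks to match the summed loop version.
    return criterion(logits, target) * output.size(1)


torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()
output = torch.randn(64, 4, 21, requires_grad=True)  # random stand-in logits
target = torch.randint(0, 21, (64, 4))               # random stand-in labels

loop_loss = calculate_loss(output, target, criterion)
kdim_loss = calculate_loss_kdim(output, target, criterion)
print(torch.allclose(loop_loss, kdim_loss))
```

If you prefer the averaged loss (the commented-out division in the original function), drop the multiplication by output.size(1).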

Yes, there is nothing in principle wrong with what you are doing.

The same as usual. output is a differentiable function of your model
parameters, the loss in your loop for each chunk is a differentiable
function of output, and summing loss over the four chunks is
differentiable, so the gradients backpropagate through your summed loss all the way back to your model parameters.
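
To illustrate the point, here is a minimal sketch that checks gradients reach the parameters through the summed per-chunk loss. The single nn.Linear layer and the random inputs are hypothetical stand-ins for the real network and data, not the model from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_chunks, num_classes = 4, 21

# Hypothetical stand-in for the real network: one linear layer whose
# output is reshaped to (batch, chunks, classes).
model = nn.Linear(16, num_chunks * num_classes)
criterion = nn.CrossEntropyLoss()

features = torch.randn(8, 16)                             # stand-in inputs
target = torch.randint(0, num_classes, (8, num_chunks))   # stand-in labels

output = model(features).view(-1, num_chunks, num_classes)

# Same summed per-chunk loss as in the question.
loss = 0
for i in range(output.size(1)):
    loss = loss + criterion(output[:, i, :], target[:, i])

loss.backward()

# Every parameter should have received a gradient from the summed loss.
for name, param in model.named_parameters():
    print(name, param.grad is not None, param.grad.abs().sum().item() > 0)
```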

Many things could cause this, but it’s not caused directly by your
“chunked-loss” function. There could be some bug in your model,
or it could just be that the problem you’re working on is hard (or
impossible), and the way your training data and model interact
makes training difficult.
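
One cheap thing to rule out is a problem in the labels themselves. The sketch below is only one possible sanity check, assuming integer targets of shape (num_samples, num_chunks); the function name summarize_targets and the random stand-in labels are hypothetical. It verifies that labels are in range and counts how often each class appears in each chunk, which makes heavy class imbalance or identical labels across chunks easy to spot.

```python
import torch


def summarize_targets(target, num_classes):
    # target: integer tensor of shape (num_samples, num_chunks).
    if target.min().item() < 0 or target.max().item() >= num_classes:
        raise ValueError("target contains labels outside [0, num_classes)")
    # One row of class counts per chunk: shape (num_chunks, num_classes).
    counts = torch.stack([
        torch.bincount(target[:, i], minlength=num_classes)
        for i in range(target.size(1))
    ])
    return counts


torch.manual_seed(0)
target = torch.randint(0, 21, (64, 4))  # random stand-in labels
counts = summarize_targets(target, num_classes=21)
print(counts.shape)       # (num_chunks, num_classes)
print(counts.sum(dim=1))  # each row sums to the number of samples
```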

Thanks for the reply, @KFrank. I looked at the data more carefully and found some anomalies in the data itself, so that might be the reason for the poor training. Thanks again for clearing up my doubts.