Modify predicted class labels to category labels before computing the loss

I am doing a semantic segmentation task and have already trained a model on a dataset that provides class labels.

Now I have a new dataset that only provides category labels rather than the class labels I used for training. To reuse the pretrained model, I can either adapt the prediction layer's channels from the number of classes to the number of categories, or map the predicted class labels to category labels and use those to compute the cross-entropy loss as before.

To me, the second option sounds better, because all layers of the pretrained model can be reused.

But if I modify the predicted labels, what are the consequences for autograd?


I’m citing @KFrank here:

The number-one rule is that your output means whatever you train it to mean…

I’m not sure what kind of classes and categories you are using and how similar they are.
E.g. if you remap “husky” and “bulldog” to a “dog” class, your model might still work fine.
Depending on how the remapping is done, you might need to retrain the model.

Basically, a category is a collection of class labels that belong to it. For example, the category flat includes road, road line, and sidewalk. The pretrained model was trained at the class level, but I want to fine-tune it to predict category labels.
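To make the setup concrete, here is a minimal sketch of such a grouping; the class IDs and the flat category are illustrative only, not taken from any particular dataset spec. A lookup tensor turns a class label map into a category label map in one indexing step:

```python
import torch

# Hypothetical mapping: classes 0-2 (road, road line, sidewalk) -> category 0
# ("flat"), classes 3-4 -> category 1 ("not flat").
class_to_cat = torch.tensor([0, 0, 0, 1, 1])

class_label_map = torch.tensor([[0, 2, 4],
                                [3, 1, 0]])        # toy H x W label map
category_label_map = class_to_cat[class_label_map]  # remap every pixel
# category_label_map == [[0, 0, 1], [1, 0, 0]]
```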

I just want to know whether modifying the prediction before computing the loss is a legal operation, because I feel that this manual modification is not differentiable. Maybe I could add an additional mapping layer right after the pretrained model that learns how to map class labels to categories?

Since the mapping is "hard-coded", you could keep the model with the class outputs and map the predictions to the corresponding categories afterwards.
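A sketch of this hard-coded variant, under the same illustrative 5-class/2-category mapping as above: at inference you can remap the argmax predictions (argmax is not differentiable, but no loss is computed there), while for training with category targets you can sum the class *probabilities* belonging to each category, which stays differentiable:

```python
import torch
import torch.nn.functional as F

num_classes, num_categories = 5, 2
class_to_cat = torch.tensor([0, 0, 0, 1, 1])  # hypothetical hard-coded mapping

logits = torch.randn(1, num_classes, 4, 4, requires_grad=True)  # N x C x H x W

# Inference only: argmax over classes, then remap (not differentiable).
pred_category = class_to_cat[logits.argmax(dim=1)]

# Training: sum class probabilities per category; index_add_ is differentiable.
probs = logits.softmax(dim=1)
cat_probs = torch.zeros(1, num_categories, 4, 4).index_add_(1, class_to_cat, probs)

target = torch.randint(0, num_categories, (1, 4, 4))
loss = F.nll_loss(torch.log(cat_probs + 1e-8), target)
loss.backward()  # gradients flow back into the class logits
```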

However, another approach would be to let the model learn to combine the classes into categories.
To do so, you could add another layer that takes the class predictions as input and outputs the categories.
This might allow the model to learn that e.g. five "medium" predictions of flat classes are a stronger signal than a single higher prediction of a non-flat class, assuming that the target is the flat category.