Multiclass Classification with sub-classes

Hello folks

I have a multi-class classification problem where the number of classes is, say, 10000 (or some other very large number). The usual way to calculate the loss is to first take the softmax of the 10000-dim raw scores (to turn them into probabilities) and then take the negative log of the true class's probability, i.e. the multi-class cross-entropy loss.
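Written out, with raw scores $z \in \mathbb{R}^{C}$ ($C = 10000$ here) and true class $y$ (the symbols are my own notation, not from any library), the loss above is:

$$
\mathcal{L}(z, y) = -\log \frac{\exp(z_y)}{\sum_{j=1}^{C} \exp(z_j)} = -z_y + \log \sum_{j=1}^{C} \exp(z_j)
$$

The second form is the one usually implemented, since the log-sum-exp term can be evaluated stably by subtracting $\max_j z_j$ before exponentiating. Restricting the softmax to a subset $S$ with $y \in S$ amounts to replacing the sum over all $j$ with a sum over $j \in S$.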

Learning a 10000-way classifier obviously needs a lot of training samples. Now, if I know that each training sample's label lies in a known, sample-specific subset of these 10000 classes, how can I use this information to my advantage? Essentially, it would involve taking the softmax over only that subset. This is different from the class imbalance problem, because each sample has its own subset.
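A minimal sketch of "softmax over only that subset" for a single sample, using only the Python standard library. The helper names `subset_cross_entropy` and `full_cross_entropy` are hypothetical, and it assumes the true class is always inside the sample's subset:

```python
import math
from typing import Sequence


def subset_cross_entropy(logits: Sequence[float], target: int, subset: Sequence[int]) -> float:
    """Cross-entropy with the softmax restricted to the class indices in `subset`.

    Hypothetical helper for illustration; assumes `target` is one of `subset`.
    """
    if target not in subset:
        raise ValueError("target class must be inside the allowed subset")
    # Gather only the scores of the allowed classes; the others never enter the denominator.
    picked = [logits[j] for j in subset]
    # Stable log-sum-exp over the subset: subtract the max before exponentiating.
    m = max(picked)
    log_denominator = m + math.log(sum(math.exp(z - m) for z in picked))
    # -log softmax_S(z)[target] = -z_target + log sum_{j in S} exp(z_j)
    return -logits[target] + log_denominator


def full_cross_entropy(logits: Sequence[float], target: int) -> float:
    """Ordinary cross-entropy: the 'subset' is simply every class."""
    return subset_cross_entropy(logits, target, range(len(logits)))


scores = [2.0, 1.0, 0.1, 3.0]
print(subset_cross_entropy(scores, 0, [0, 2]))  # softmax over classes {0, 2} only
print(full_cross_entropy(scores, 0))            # softmax over all 4 classes
```

Because the denominator only sums over the allowed classes, the restricted loss is never larger than the full one. In a batched setting, one option (an assumption about your pipeline) is to pad each sample's subset to a fixed size and gather those columns; that only saves work when the subsets are much smaller than 10000.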

I thought of using a boolean mask (multiplying the scores by 0/1), but that would skew the softmax denominator, because every masked-out class would still contribute a term exp(0) = 1.
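The usual fix is to mask with -inf instead of 0: a masked-out score of -inf contributes exp(-inf) = 0 to the denominator, so the result equals the softmax over the allowed subset. The stdlib-only sketch below (hypothetical helper names) contrasts that with the 0-mask described above. If you are using PyTorch (an assumption, since the post does not say), the same idea is commonly written as `logits.masked_fill(~mask, float('-inf'))` followed by `torch.nn.functional.cross_entropy`, which keeps the batch as one dense tensor, so the extra cost over the plain loss is one elementwise operation.

```python
import math
from typing import Sequence


def masked_cross_entropy(logits: Sequence[float], target: int, mask: Sequence[bool]) -> float:
    """Cross-entropy where classes with mask[j] == False are excluded from the softmax.

    Disallowed scores are replaced by -inf (not multiplied by 0), so each adds
    exp(-inf) = 0 to the denominator instead of exp(0) = 1.
    """
    if not mask[target]:
        raise ValueError("target class must be allowed by the mask")
    masked = [z if keep else -math.inf for z, keep in zip(logits, mask)]
    m = max(masked)  # finite, because at least the target class is allowed
    # math.exp(-inf - m) == 0.0, so masked-out classes vanish from the sum.
    log_denominator = m + math.log(sum(math.exp(z - m) for z in masked))
    return -masked[target] + log_denominator


def zero_masked_cross_entropy(logits: Sequence[float], target: int, mask: Sequence[bool]) -> float:
    """The skewed variant: masked-out scores become 0 and each contributes exp(0) = 1."""
    zeroed = [z if keep else 0.0 for z, keep in zip(logits, mask)]
    m = max(zeroed)
    log_denominator = m + math.log(sum(math.exp(z - m) for z in zeroed))
    return -zeroed[target] + log_denominator


scores = [2.0, 1.0, 0.1, 3.0]
allowed = [True, False, True, False]  # this sample may only be class 0 or class 2
print(masked_cross_entropy(scores, 0, allowed))       # same as softmax over {0, 2}
print(zero_masked_cross_entropy(scores, 0, allowed))  # inflated by two extra exp(0) terms
```

The only requirement is that every row keeps at least one unmasked class (the target); a row that is entirely -inf would produce NaN in most frameworks.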

Any pointers would be helpful. Oh, and I don't want to sacrifice too much speed 🙂

Thanks!