Hi George!
Regardless of whether yours is a multi-label problem or not, probably your
best approach is to compute your loss just at the subclass level and not add
a superclass term to it. If you train to predict the correct subclasses (which
is what you should be trying to do), you will also automatically be training to
predict the correct superclasses. You haven’t offered any explanation of why
you need or want to add a superclass term.
(You’ve told us more-or-less nothing about your concrete use case,* so it’s
hard to offer useful advice as to what approach might be best.)
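For concreteness, here is a minimal sketch of that plain subclass-only
approach. Everything in it (the model, the feature dimension, the batch size,
the variable names) is made up purely for illustration; it’s just the ordinary
single-label recipe:

import torch

torch.manual_seed (2023)

n_features = 16                                      # hypothetical feature dimension
n_subclasses = 6                                     # hypothetical number of subclasses

model = torch.nn.Linear (n_features, n_subclasses)   # stand-in for your real network
loss_fn = torch.nn.CrossEntropyLoss()

input = torch.randn (8, n_features)                  # dummy batch of eight samples
target = torch.randint (n_subclasses, (8,))          # integer subclass labels

loss = loss_fn (model (input), target)               # subclass-level loss only
loss.backward()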
Having said that, let me answer your technical question in the single-label
context.
Your network will predict unnormalized log-probabilities for your subclasses.
These will typically be the output of your final Linear layer (without any
subsequent non-linear activations). You would normally pass these directly
into CrossEntropyLoss.
Conceptually, you want to combine your subclass probabilities into superclass
probabilities that you would then pass into a superclass cross-entropy loss
function.
For numerical reasons (essentially the same reasons that CrossEntropyLoss
takes log-probabilities rather than plain probabilities), you should do all of this
in log-space – that is, always work with log-probabilities without ever explicitly
converting them to probabilities – combining your subclass log-probability
predictions into superclass log-probabilities (that you would then pass into a
superclass-level CrossEntropyLoss).
First pass the unnormalized subclass log-probabilities through log_softmax()
to convert them to normalized log-probabilities.
At this point, conceptually, you would use exp() to obtain subclass probabilities,
you would sum the subclass probabilities within a given superclass to obtain the
probability for that superclass, and then call log() to obtain that superclass’s
log-probability.
But we wish to perform this manipulation in log-space, so, instead, we will select
the subclass log-probabilities for the subclasses of a given superclass and “add”
them together with logsumexp() to obtain that superclass’s log-probability
“directly.”
Consider:
>>> import torch
>>> print (torch.__version__)
2.1.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> # example with six subclasses and two superclasses
>>>
>>> # subclass superclass
>>>
>>> # 0: bat 1: mammal
>>> # 1: bear 1: mammal
>>> # 2: crow 0: bird
>>> # 3: dog 1: mammal
>>> # 4: pigeon 0: bird
>>> # 5: robin 0: bird
>>>
>>> class_map = torch.tensor ([ # membership of subclass (column) in superclass (row)
... [False, False, True, False, True, True],
... [ True, True, False, True, False, False]
... ])
>>>
>>> subu = torch.randn (6) # unnormalized subclass log-probabilities
>>> subn = subu.log_softmax (dim = 0) # normalized subclass log-probabilities
>>>
>>> subn
tensor([-3.0564, -1.2996, -2.2345, -1.1579, -2.5913, -1.6918])
>>> subn.exp()
tensor([0.0471, 0.2727, 0.1070, 0.3141, 0.0749, 0.1842])
>>>
>>> supern = torch.empty (class_map.size (0)) # storage for superclass (normalized) log-probabilities
>>> for i in range (class_map.size (0)): # compute log-probability for each superclass
... supern[i] = subn[class_map[i]].logsumexp (dim = 0)
...
>>> supern
tensor([-1.0047, -0.4559])
>>> supern.exp()
tensor([0.3661, 0.6339])
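And, to connect this back to your actual question, here is a hedged sketch of
the same superclass computation done batch-wise, ending in a superclass-level
CrossEntropyLoss. The batch size, the labels, and the 0.5 weighting factor are
made up just for illustration; the additive -inf mask simply reproduces the
indexing loop above in a single broadcasted logsumexp():

import torch

torch.manual_seed (2023)

class_map = torch.tensor ([                            # same membership matrix as above
    [False, False, True, False, True, True],
    [ True, True, False, True, False, False]
])
n_super, n_sub = class_map.shape

# additive mask: 0.0 where a subclass belongs to the superclass, -inf elsewhere
mask = torch.full (class_map.shape, float ('-inf'))
mask[class_map] = 0.0

subu = torch.randn (8, n_sub, requires_grad = True)    # batch of unnormalized subclass log-probabilities
subn = subu.log_softmax (dim = 1)                      # normalized subclass log-probabilities

# broadcast (8, 1, n_sub) + (n_super, n_sub) -> (8, n_super, n_sub), then
# logsumexp over the member subclasses of each superclass -> (8, n_super)
supern = (subn.unsqueeze (1) + mask).logsumexp (dim = 2)

sub_target = torch.randint (n_sub, (8,))               # hypothetical subclass labels
super_target = torch.randint (n_super, (8,))           # hypothetical superclass labels

loss_fn = torch.nn.CrossEntropyLoss()

# subclass loss plus a (hypothetically weighted) superclass loss
loss = loss_fn (subu, sub_target) + 0.5 * loss_fn (supern, super_target)
loss.backward()

Because supern is already normalized in log-space, the log_softmax() that
CrossEntropyLoss applies internally leaves it unchanged, so it can be passed
in directly.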
*) For example, what is the input data to your model? Images? Time series?
Sets of disparate descriptive values? And what does that data mean? What
are your classes? Is your training data balanced or unbalanced? How much
training data do you have? What are your most important performance metrics?
Best.
K. Frank