NLL training using a weighted sum of softmax-normalized predictions

I have developed a classification model in which the final class prediction is an attention-weighted sum of multiple softmax-normalized probability distributions.

This is an instantiation of Multiple Instance Learning (MIL). A classic use case is computer vision: you make separate predictions for many individual patches of an image, but only have training labels for the image as a whole. The prediction for the whole image is then a sum of the patch-level predictions, weighted by attention.

Here is a stripped-down example with 5 classes, where the final prediction is a weighted sum of 3 individual predictions (I use a batch size of 1 for simplicity):

import torch
import torch.nn as nn

# individual predictions obtained via softmax
# [batch_size x num_patches x num_classes]
individual_preds = torch.tensor([[[0.10, 0.20, 0.20, 0.25, 0.25],
                                  [0.15, 0.15, 0.20, 0.20, 0.30],
                                  [0.25, 0.25, 0.20, 0.30, 0.00]]])

# weights of individual predictions [batch_size x num_patches x 1]
attn = torch.tensor([[[0.25],
                      [0.25],
                      [0.50]]])  # example weights; they sum to 1 over the patch dimension

# final prediction as a weighted sum
# of individual ones [batch_size x num_classes]
pred = torch.sum(attn * individual_preds, dim=1)

# 'manually' taking the log probabilities for NLL
log_pred = pred.log()

# pseudo-target for loss [batch_size]
target = torch.tensor([1])

# NLL loss
crit = nn.NLLLoss()
loss = crit(log_pred, target)
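Note that `pred` is itself a valid probability distribution only because the attention weights sum to 1 over the patch dimension (in my model they come out of a softmax over attention scores). A quick sanity check, using the same made-up numbers as above:

```python
import torch

individual_preds = torch.tensor([[[0.10, 0.20, 0.20, 0.25, 0.25],
                                  [0.15, 0.15, 0.20, 0.20, 0.30],
                                  [0.25, 0.25, 0.20, 0.30, 0.00]]])
attn = torch.tensor([[[0.25],
                      [0.25],
                      [0.50]]])  # sums to 1 over dim=1

# weighted sum over patches -> [batch_size x num_classes]
pred = torch.sum(attn * individual_preds, dim=1)
print(pred.sum(dim=1))  # sums to ~1, so pred is a proper distribution
```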

I know it is recommended to use a combination of LogSoftmax and NLLLoss (or simply CrossEntropyLoss) for classification tasks, to avoid numerical instabilities.
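To illustrate the instability I mean (with deliberately extreme, made-up logits): computing softmax first and then taking the log underflows to zero probabilities and hence `-inf` log-probabilities, whereas `log_softmax` stays finite because it uses the log-sum-exp trick internally:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 0.0, -1000.0]])

# softmax underflows for the small logits, so the log produces -inf
naive = torch.log(F.softmax(logits, dim=1))

# log_softmax is computed as logits - logsumexp(logits) and stays finite
stable = F.log_softmax(logits, dim=1)

print(naive)   # contains -inf entries
print(stable)  # all entries finite
```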

In this case, I cannot think of a way to avoid computing softmax-normalized probabilities and then ‘manually’ taking the log before pushing them through NLLLoss. The problem is that a weighted sum of predictions is required, and this only makes sense in probability space – if the individual predictions were already in log space, I could not simply combine them (the log of a sum does not decompose into a sum of logs).
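For what it's worth, the failure mode I'm worried about would surface whenever `pred` contains an exact zero for some class (as in my example, where patch 3 assigns 0.00 to the last class): the manual log yields `-inf`, and the loss becomes infinite if that class is the target. A small demonstration, plus a clamp-based guard I could use (the `eps` value here is an arbitrary choice of mine, not a principled fix):

```python
import torch

# combined prediction with zero mass on the last class
pred = torch.tensor([[0.5, 0.5, 0.0]])
print(pred.log())  # contains -inf for the zero entry

# possible guard: clamp probabilities away from zero before the log
eps = 1e-12
log_pred = pred.clamp_min(eps).log()
print(log_pred)    # all entries finite
```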

Should I worry about numerical instability because I’m applying Softmax, Log, and NLLLoss separately like this? Is there a way to achieve what I want that I am currently missing?