Weighted loss function, with learnable weights

Dear all,
I would like to ask for some help. I am training a dual-path CNN: one path processes the image holistically, while the other path processes the same image patch-wise. That is, I extract N_patches patches from the same image and feed them to a second CNN, where every patch goes through the same CNN (shared weights).

My idea is to build a combined loss function. In the local processing path (of the two-path model), a separate loss is computed for each patch. My question is how to weight these local losses so that the weights are learnable during training.
In the final stage I combine the global loss and the local loss (the local loss is then averaged).

I have already trained the combined CNN model, where in the final stage I combined the global image features and the local image features, but now I want to try learning local losses from the local processing path.

I did something like this:

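```python
import torch
import torch.nn as nn

class PatchRelevantLoss(nn.Module):
    def __init__(self, n_patches):
        super().__init__()
        # reduction='none' keeps one loss value per sample
        self.ce = nn.CrossEntropyLoss(reduction='none')
        self.n_patches = n_patches
        # One learnable weight per patch (dim: n_patches); the post left the
        # init open, so torch.ones here is only a placeholder
        self.weights = nn.Parameter(torch.ones(n_patches))

    def forward(self, feature_matrix, labels):
        # feature_matrix: (n_patches, batch_size, n_classes); labels: (batch_size,)
        losses = []
        for features, weight in zip(feature_matrix, self.weights):
            # weighted per-sample loss of this patch, averaged over the batch
            losses.append((self.ce(features, labels) * weight).mean(0))
        return torch.stack(losses, 0)  # shape: (n_patches,)
```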

Shapes: feature_matrix has shape (n_patches, batch_size, feature_dimension), where the last dimension holds the class logits that nn.CrossEntropyLoss expects.
Now the idea is to make the weights learnable.
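The final combination described above (global loss plus the averaged, weighted local losses) can be sketched in vectorized form, assuming the last dimension of the patch outputs holds class logits. The function name combined_loss and the argument patch_weights are hypothetical; note that, as pointed out later in the thread, learning patch_weights by minimizing this same loss is degenerate.

```python
import torch
import torch.nn.functional as F

def combined_loss(global_logits, local_logits, labels, patch_weights):
    """Global CE plus the averaged, per-patch-weighted local CE (hypothetical helper).

    global_logits: (batch_size, n_classes)
    local_logits:  (n_patches, batch_size, n_classes)
    labels:        (batch_size,)
    patch_weights: (n_patches,)
    """
    n_patches, batch_size, _ = local_logits.shape
    global_loss = F.cross_entropy(global_logits, labels)
    # Fold patches into the batch dimension; labels.repeat matches the
    # patch-major order of flatten(0, 1), so every patch gets the image label.
    per_sample = F.cross_entropy(
        local_logits.flatten(0, 1), labels.repeat(n_patches), reduction="none"
    ).view(n_patches, batch_size)
    per_patch = per_sample.mean(1)                   # (n_patches,)
    local_loss = (per_patch * patch_weights).mean()  # weight each patch, then average
    return global_loss + local_loss
```

Avoiding the Python loop over patches keeps this a single cross-entropy call; with all weights equal to 1 it reduces to the global loss plus the plain average of the patch losses.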

Here is the basic idea in a picture.

I am new to PyTorch and would be thankful for any help.
Cheers.

IIUC, with your architecture these weights are hyperparameters (analogous to learning rates) and hence can only be tuned externally, not learned by minimizing the training loss.

Thanks for the answer, but unfortunately I need more clarification. Does that mean I need to define them in the CNN architecture and add an extra network output that later serves as an input to the custom loss function?
I see there are some implementations of trainable weights in multi-task losses, but I am not sure how they apply to this specific problem.

Weights on loss terms are not trainable, because loss*w is minimized by w=0 (or by other degenerate solutions once negative values are allowed). Even constraints like sum(w)=1 don’t have a sensible interpretation: the optimizer would simply put all the weight on the patch with the smallest loss. Normally you would combine predictions instead.
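A tiny numeric illustration of this point, using fixed, made-up per-patch loss values (not results from the thread): plain gradient descent on w * loss pushes every weight negative, and with softmax-normalized weights (sum(w)=1) the mass collapses onto the patch with the smallest loss.

```python
import torch

# Fixed per-patch losses (made-up stand-ins for the patch CE values).
patch_losses = torch.tensor([0.9, 0.4, 1.3, 0.7])

# 1) Unconstrained weights: d(sum(w * loss))/dw = loss > 0, so every
#    SGD step lowers each weight by lr * loss, without bound.
w = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.SGD([w], lr=0.5)
for _ in range(100):
    opt.zero_grad()
    (w * patch_losses).sum().backward()
    opt.step()
print(w.data)  # every entry is now negative and still decreasing

# 2) Softmax-normalized weights: the minimum of sum(softmax(z) * loss) puts
#    (almost) all weight on the smallest loss, so the other patches are ignored.
z = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.SGD([z], lr=0.5)
for _ in range(500):
    opt.zero_grad()
    (z.softmax(-1) * patch_losses).sum().backward()
    opt.step()
print(z.softmax(-1).data)  # concentrated on index 1, the smallest loss
```

Neither outcome says anything about which patches are actually informative, which is why the weights end up acting like hyperparameters rather than quantities learned from this loss.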


Thanks for the guidance. Does that mean I should optimize W based on the prediction outputs and then take their weighted average? For example, define W as learnable parameters in the CNN model and output the weighted average of the predictions, with W being learnable?

Generally, you would want to have autonomous (separately trained) submodels before creating an “ensemble”; you can then take the weighted average of their predictions with trained combination weights passed through a softmax.

Without this, i.e. if you train everything at once with one objective (loss) across all parts, you’ll just get a higher-capacity model whose parts (and combination weights) learn to compensate for each other's output errors.
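A minimal sketch of such an ensemble, assuming the submodels were trained on their own beforehand and are then frozen, so only the combination logits are learned. SoftmaxEnsemble and combo_logits are hypothetical names, and averaging softmax probabilities (rather than logits) is one choice among several.

```python
import torch
import torch.nn as nn

class SoftmaxEnsemble(nn.Module):
    """Weighted average of predictions from already-trained submodels."""

    def __init__(self, submodels):
        super().__init__()
        self.submodels = nn.ModuleList(submodels)
        # Freeze the (separately trained) submodels.
        for p in self.submodels.parameters():
            p.requires_grad_(False)
        # One logit per submodel; softmax turns them into positive weights summing to 1.
        self.combo_logits = nn.Parameter(torch.zeros(len(submodels)))

    def forward(self, x):
        # probs: (n_models, batch, n_classes)
        probs = torch.stack([m(x).softmax(-1) for m in self.submodels])
        w = self.combo_logits.softmax(-1)          # (n_models,)
        return (w[:, None, None] * probs).sum(0)   # (batch, n_classes), rows sum to 1

# Toy usage with linear "submodels" standing in for trained CNNs.
torch.manual_seed(0)
ens = SoftmaxEnsemble([nn.Linear(10, 3) for _ in range(3)])
out = ens(torch.randn(4, 10))
loss = nn.functional.nll_loss(out.clamp_min(1e-8).log(), torch.tensor([0, 1, 2, 0]))
opt = torch.optim.Adam([ens.combo_logits], lr=0.1)
opt.zero_grad()
loss.backward()
opt.step()
```

Because the submodels are fixed, the combination weights can only express how much to trust each already-competent predictor; they cannot co-adapt with the submodels the way jointly trained parts do.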


Thank you for the insights, I learned a lot from you. Yes, the idea is for the local processing path to compensate for the information missing from the global processing path (where the image is processed holistically), and for the submodels in the local processing path to be trained with softmax-ed combination weights.

So basically, as you said, I can use either a single loss function across all parts (a bigger-capacity model) or several losses, each corresponding to a single patch?

Which way seems more meaningful to you: many outputs from the local branch, each multiplied by a learnable weight and optimized with several losses (one loss per patch), or a weighted-average ensemble trained with a single loss function?

For corrections, perhaps enforce an order:
```
y0 = globalCnn(x; pars)
y1 = f(y0, localCnn(x_patch[0]; lpars)) + y0
y2 = f(y1, localCnn(x_patch[1]; lpars)) + y1
…
loss = criterion(yN, target_y)
```

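There are many possibilities for f(); the key is that this computation is no longer permutation invariant, since each step depends on the prediction produced by the previous patches, so reordering the patches changes yN.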
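One way to turn this pseudocode into runnable PyTorch, assuming f is a linear layer applied to the concatenation of the current prediction and the local CNN output (a hypothetical choice; the thread deliberately leaves f open). The class name, the toy global/local CNNs, and the random input tensors are stand-ins for illustration.

```python
import torch
import torch.nn as nn

class SequentialPatchCorrector(nn.Module):
    """y0 = global prediction; each patch then adds a residual correction, in a fixed order."""

    def __init__(self, global_cnn, local_cnn, n_classes, local_dim):
        super().__init__()
        self.global_cnn = global_cnn  # "pars"
        self.local_cnn = local_cnn    # "lpars", shared by all patches
        # Hypothetical f(): linear layer on [current prediction, local features].
        self.f = nn.Linear(n_classes + local_dim, n_classes)

    def forward(self, x, patches):
        # x: (batch, C, H, W); patches: (n_patches, batch, C, h, w) in a fixed order
        y = self.global_cnn(x)
        for patch in patches:
            z = self.local_cnn(patch)
            y = self.f(torch.cat([y, z], dim=-1)) + y  # y_{k+1} = f(y_k, local) + y_k
        return y

# Toy usage: flatten + linear layers stand in for real CNNs.
torch.manual_seed(0)
model = SequentialPatchCorrector(
    nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5)),
    nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 16)),
    n_classes=5, local_dim=16,
)
x = torch.randn(2, 3, 8, 8)
patches = torch.randn(4, 2, 3, 4, 4)
loss = nn.CrossEntropyLoss()(model(x, patches), torch.tensor([1, 3]))
loss.backward()
```

Since f sees the running prediction y, feeding the patches in a different order gives a different final output, which is exactly the loss of permutation invariance described above.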

More generally, patch-specific losses don’t feel right to me, idk…


Thanks, I only have one more question. Since there are trainable parameters W, what is the most common way to initialize them for this specific task?

Is this good practice:

```python
self.weights = nn.Parameter(torch.FloatTensor(self.n_patches)).unsqueeze(1).cuda()
nn.init.uniform_(self.weights)
```

And then normalize them so that sum(w) = 1? Or is it enough to initialize them uniformly without normalization? Or would another initialization procedure be more suitable in this case?
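One detail worth checking in the snippet above: calling .unsqueeze(1) (or .cuda()) on an nn.Parameter returns a plain tensor, so the module does not register it as a parameter and an optimizer built from model.parameters() never updates it. A small sketch contrasting the two patterns (the class names are hypothetical):

```python
import torch
import torch.nn as nn

class NotRegistered(nn.Module):
    def __init__(self, n_patches):
        super().__init__()
        # .unsqueeze(1) returns a new non-leaf tensor, not an nn.Parameter,
        # so it is stored as an ordinary attribute and is invisible to optimizers.
        self.weights = nn.Parameter(torch.empty(n_patches)).unsqueeze(1)

class Registered(nn.Module):
    def __init__(self, n_patches):
        super().__init__()
        # Create the Parameter with its final shape, then initialize it in place
        # (torch.empty, like torch.FloatTensor(n), leaves memory uninitialized).
        self.weights = nn.Parameter(torch.empty(n_patches, 1))
        nn.init.uniform_(self.weights)

print(len(list(NotRegistered(4).parameters())))  # 0
print(len(list(Registered(4).parameters())))     # 1
```

To run on a GPU, move the whole module after construction with model.to(device) (or model.cuda()) rather than calling .cuda() on the parameter itself.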

If you need vector(s) with sum(w)=1:
```python
self.w = nn.Parameter(torch.zeros(n))
```
then in forward:
```python
w = self.w.softmax(-1)
```

You only need a random / asymmetric init if the components are non-identifiable (swappable). Here each weight is tied to a fixed patch position, so the zero init above is fine.
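Put together as a small module, the suggestion looks like this (PatchWeights is a hypothetical name). With the zero init, every patch starts with weight 1 / n_patches, and no random init is needed because each entry already belongs to a distinct patch position.

```python
import torch
import torch.nn as nn

class PatchWeights(nn.Module):
    def __init__(self, n_patches):
        super().__init__()
        # Zero logits -> uniform weights after softmax.
        self.w = nn.Parameter(torch.zeros(n_patches))

    def forward(self):
        # Positive weights that sum to 1, whatever values self.w takes.
        return self.w.softmax(-1)

pw = PatchWeights(4)
print(pw().detach())  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```

Parameterizing through softmax keeps the constraint satisfied during training without any explicit renormalization step.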


Thank you for the help, mate. This is part of my research project; I’ll report back on how it goes. We’ll see whether this approach gives a performance boost.
Have a great day.