I want to compute the loss using only the classes that appear in the current batch. To do that, I select those classes from the output logits by index, build a reduced logits tensor, and compute the loss against the batch targets relabeled to the range [0, n).
The code is as follows:
```python
if training:
    # Labels trick: restrict the loss to the classes present in this batch.
    # Note: tensor.unique() already returns sorted values; calling .sort()
    # on the result would return a (values, indices) tuple, not a tensor.
    unq_lbls = labels.unique()
    # Reassign new labels (0, 1, ...) to the batch targets
    for lbl_idx, lbl in enumerate(unq_lbls):
        labels[labels == lbl] = lbl_idx
    # Calculate loss only over the heads that appear in the batch
    loss = criterion(outputs[:, unq_lbls], labels)
else:
    loss = criterion(outputs, labels)
```
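To make the remapping concrete, here is a minimal, self-contained sketch of the same trick. The tensors (`outputs`, `labels`) and the batch values are made up for illustration; it also remaps into a clone of `labels` rather than in place, which avoids any chance of a new index colliding with an original label value:

```python
import torch

# Hypothetical batch: 4 samples, 10 classes, only classes {2, 5, 7} present
outputs = torch.randn(4, 10, requires_grad=True)  # stand-in for model logits
labels = torch.tensor([7, 2, 7, 5])

unq_lbls = labels.unique()      # tensor([2, 5, 7]); unique() returns sorted values
remapped = labels.clone()
for new_idx, lbl in enumerate(unq_lbls):
    remapped[labels == lbl] = new_idx   # 2 -> 0, 5 -> 1, 7 -> 2

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(outputs[:, unq_lbls], remapped)  # loss over 3 heads only
loss.backward()

print(remapped)  # tensor([2, 0, 2, 1])
# Columns for classes absent from the batch receive zero gradient,
# since indexing only routes gradients into the selected columns:
print(outputs.grad[:, 0].abs().sum())  # zero for absent class 0
```

A useful consequence: because `outputs[:, unq_lbls]` is differentiable indexing, backpropagation already leaves the gradients of the non-selected logit columns at zero, so with a plain optimizer (no momentum or weight decay) those classifier weights are not updated.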
My question is how to use this loss to update the weights of the classifier and the rest of the model. My thinking is that I should freeze the weights of the classes not present in the current batch and update only the relevant connections per batch. I would like a second opinion on whether this is correct, and if so, a hint on the implementation would be great.