I want to compute the loss using only the classes that appear in the current batch. To do that, I select those classes from the output logits by index, build a reduced logits tensor, and compute the loss against the batch targets relabeled to the range [0, n).
The code is as follows:
```python
if training:
    # Labels trick: restrict the loss to the classes present in this batch.
    # Note: tensor.unique() already returns sorted values; calling .sort()
    # on the result would return a (values, indices) tuple, not a tensor.
    unq_lbls = labels.unique()
    # Reassign new labels (0, 1, ...) to the batch targets
    for lbl_idx, lbl in enumerate(unq_lbls):
        labels[labels == lbl] = lbl_idx
    # Calculate loss only over the heads that appear in the batch
    loss = criterion(outputs[:, unq_lbls], labels)
else:
    loss = criterion(outputs, labels)
```
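To make the remapping concrete, here is a minimal, self-contained sketch of the same trick. The tensors (`outputs`, `labels`) and the batch values are made up for illustration; it also remaps into a clone of `labels` rather than in place, which avoids any chance of a new index colliding with an original label value:

```python
import torch

# Hypothetical batch: 4 samples, 10 classes, only classes {2, 5, 7} present
outputs = torch.randn(4, 10, requires_grad=True)  # stand-in for model logits
labels = torch.tensor([7, 2, 7, 5])

unq_lbls = labels.unique()      # tensor([2, 5, 7]); unique() returns sorted values
remapped = labels.clone()
for new_idx, lbl in enumerate(unq_lbls):
    remapped[labels == lbl] = new_idx   # 2 -> 0, 5 -> 1, 7 -> 2

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(outputs[:, unq_lbls], remapped)  # loss over 3 heads only
loss.backward()

print(remapped)  # tensor([2, 0, 2, 1])
# Columns for classes absent from the batch receive zero gradient,
# since indexing only routes gradients into the selected columns:
print(outputs.grad[:, 0].abs().sum())  # zero for absent class 0
```

A useful consequence: because `outputs[:, unq_lbls]` is differentiable indexing, backpropagation already leaves the gradients of the non-selected logit columns at zero, so with a plain optimizer (no momentum or weight decay) those classifier weights are not updated.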
My question is how to use this loss to update the weights of the classifier and the rest of the model. My thinking is that I should freeze the weights of the classes not present in the current batch and update only the relevant connections per batch. I would like a second opinion on whether this is correct, and if so, a hint on the implementation would be great.