Update ResNet18 Classifier according to the classes in the batch

If I want to compute the loss using only the labels present in the current batch, I can select those classes' columns from the output logits by index, build a new logits tensor from them, and compute the loss against the batch targets relabeled to 0 ... n-1 (where n is the number of unique classes in the batch).
The code is as follows:

```python
if training:
    # Use the labels trick
    # Get the unique labels of the current batch (sorted, for reassignment)
    unq_lbls = labels.unique().sort()[0]
    # Assign new labels (0, 1, ...) in sorted order
    for lbl_idx, lbl in enumerate(unq_lbls):
        labels[labels == lbl] = lbl_idx
    # Calculate the loss only over the heads that appear in the batch
    loss = criterion(outputs[:, unq_lbls], labels)
else:
    loss = criterion(outputs, labels)
```
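The relabeling loop can also be written without a Python loop. The sketch below assumes `criterion` is plain cross-entropy (`F.cross_entropy`) and uses `torch.unique(..., return_inverse=True)`, which returns the sorted unique labels together with each original label's position in that sorted list, i.e. exactly the 0 ... n-1 targets. The function name `labels_trick_loss` is hypothetical, and unlike the loop above it does not modify `labels` in place.

```python
import torch
import torch.nn.functional as F

def labels_trick_loss(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Cross-entropy computed only over the classes present in the batch."""
    # unq_lbls: sorted distinct labels; new_labels: index of each label in unq_lbls
    unq_lbls, new_labels = torch.unique(labels, sorted=True, return_inverse=True)
    # Keep only the logit columns of the classes in this batch
    return F.cross_entropy(outputs[:, unq_lbls], new_labels)

# Usage: 4 samples, 10 classes, only classes 2, 5 and 7 appear in the batch
outputs = torch.randn(4, 10)
labels = torch.tensor([7, 2, 7, 5])
print(labels_trick_loss(outputs, labels))
```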

My question is how to use this loss to update the weights of the classifier and the rest of the model. I think I should freeze the weights of the classes that are not in the current batch and update only the connections relevant to each batch. I would like a second opinion on whether this is correct, and if it is, a hint on how to implement it would be great.
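A quick way to see what this loss already does to the classifier, before adding any explicit freezing, is to inspect the gradients after `backward()`. The sketch below is an assumption-laden check, not a full answer: `head` is a hypothetical `nn.Linear(512, 10)` standing in for ResNet18's final `fc` layer, `features` is random data standing in for the backbone output, and the optimizer is plain `torch.optim.SGD` with no momentum and no weight decay.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for ResNet18's final layer: 512 features -> 10 classes
head = nn.Linear(512, 10)
features = torch.randn(8, 512)  # pretend backbone output for a batch of 8
labels = torch.tensor([1, 3, 3, 1, 6, 6, 1, 3])  # only classes 1, 3, 6 present

outputs = head(features)
unq_lbls, new_labels = torch.unique(labels, sorted=True, return_inverse=True)
loss = F.cross_entropy(outputs[:, unq_lbls], new_labels)
loss.backward()

# Weight rows of classes absent from the batch receive exactly zero gradient,
# because the sliced-away logit columns never contribute to the loss.
row_has_grad = head.weight.grad.abs().sum(dim=1) > 0
print("rows with gradient:", row_has_grad.nonzero().flatten().tolist())

# With plain SGD (no momentum, no weight decay), zero-gradient rows stay unchanged
opt = torch.optim.SGD(head.parameters(), lr=0.1)
before = head.weight.detach().clone()
opt.step()
changed = (head.weight.detach() != before).any(dim=1)
print("rows updated:", changed.nonzero().flatten().tolist())
```

Under these assumptions only the rows for classes 1, 3 and 6 get a gradient and are updated. Note that this holds only for optimizers whose update is zero when the gradient is zero: with `momentum` or `weight_decay` in SGD, or with Adam, the rows of absent classes can still change, so explicit freezing would matter there. The backbone weights are shared by all classes and receive gradient from every batch.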