I think you could use the raw loss output (via `reduction='none'`), set the unwanted loss entries to zero, reduce the loss, and calculate the gradients via `loss.backward()`. The zeroed entries should then no longer contribute to the reduced loss or to the gradients. I'm not sure if there is a better way to mask the loss.
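Something like this might work as a minimal sketch of the steps above. The model, the data shapes, the choice of `nn.MSELoss`, and the `mask` tensor are all made up for illustration, so swap in your own criterion and masking condition:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: 5 samples, 4 input features, 3 outputs
model = nn.Linear(4, 3)
x = torch.randn(5, 4)
target = torch.randn(5, 3)

# Hypothetical mask: 1 keeps a loss entry, 0 drops it
mask = torch.ones_like(target)
mask[0] = 0.0      # ignore the whole first sample
mask[2, 1] = 0.0   # ignore a single entry

# reduction='none' returns the unreduced, per-element loss
criterion = nn.MSELoss(reduction='none')
raw_loss = criterion(model(x), target)  # shape [5, 3]

# Zero out the unwanted entries
masked_loss = raw_loss * mask

# Reduce: average over the kept entries only (one possible choice)
loss = masked_loss.sum() / mask.sum().clamp(min=1)

loss.backward()
```

Note that dividing by `mask.sum()` averages over the kept entries only, while `masked_loss.mean()` would also count the zeroed entries in the denominator. Which one you want depends on your use case.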