Thanks for the answer. I tried applying the idea you mentioned by filtering the logits with my mask before passing them to the loss: a tensor of shape [512, VOCAB_SIZE] is reduced to [2, VOCAB_SIZE], keeping only the two tokens on which I want to update the model. However, the FLOPs reported by the Hugging Face Trainer didn't change at all. I don't know whether the technique is wrong or whether the FLOPs in this case simply wouldn't reflect it, but I'd expect them to, especially since backpropagation costs roughly twice as much computation as the forward pass. Besides, I'm reducing from 512 tokens to just 2, so in my view there should be a big difference.
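To make the filtering step concrete, here is a minimal standalone sketch (toy shapes, random data, no real model) of what the masking does to the loss computation: only the masked rows of the logits enter the cross-entropy, and after `backward()` only those rows carry gradient. Note the forward pass still produces all 512 rows, which is one reason reported FLOPs may not move.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real shapes: 512 positions, small vocab for the demo.
SEQ_LEN, VOCAB_SIZE = 512, 10

logits = torch.randn(1, SEQ_LEN, VOCAB_SIZE, requires_grad=True)
labels = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN))

# loss_mask marks the two positions whose tokens should drive the update.
loss_mask = torch.zeros(1, SEQ_LEN, dtype=torch.bool)
loss_mask[0, 5] = True
loss_mask[0, 100] = True

active = loss_mask.view(-1)
active_logits = logits.view(-1, VOCAB_SIZE)[active]   # shape: [2, VOCAB_SIZE]
active_labels = labels.view(-1)[active]

loss = nn.CrossEntropyLoss()(active_logits, active_labels)
loss.backward()

# Only the two masked rows receive gradient; the other 510 rows stay zero,
# even though the forward pass computed all of them.
grad = logits.grad.view(-1, VOCAB_SIZE)
print(active_logits.shape)
print(grad[~active].abs().sum())
```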
```python
import torch
from torch import nn
import transformers

class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')
        outputs = model(**inputs)
        logits = outputs.logits
        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)
        # Detach the logits at positions outside the mask so no gradient
        # flows through them.
        inactive_positions = ~loss_mask.view(-1).bool()
        logits.view(-1, model.config.vocab_size)[inactive_positions] = \
            logits.view(-1, model.config.vocab_size)[inactive_positions].detach()
        # Keep only the masked positions for the loss: [512, V] -> [2, V].
        active_positions = loss_mask.view(-1).bool()
        active_logits = logits.view(-1, model.config.vocab_size)[active_positions]
        active_labels = labels.view(-1)[active_positions]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(active_logits, active_labels)
        return (loss, outputs) if return_outputs else loss
```
Trainer output: