How can I exclude specific tokens from backpropagation in the output of a transformer decoder?

Thanks for the answer. To apply the idea you mentioned, I tried filtering the logits with my mask before passing them to the loss. In the end, a tensor of shape [512, VOCAB_SIZE] was reduced to [2, VOCAB_SIZE], these being the only two token positions on which I want to update the model. However, looking at the FLOPs reported by the Hugging Face Trainer, nothing changed. Now I don't know whether the technique is wrong or whether the FLOPs in this case really wouldn't show anything, but I believe they should, especially considering that backpropagation costs roughly twice as much compute as the forward pass. Besides, I'm reducing from 512 tokens to just 2, so in my view there should be a big difference.
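For what it's worth, here is a minimal sketch in plain PyTorch (a toy matmul, nothing from my actual setup) of why the backward FLOPs might not shrink: the backward of a matmul runs over all positions regardless of how many rows of the output gradient are zero, so dropping positions at the loss doesn't reduce the per-op backward work.

```python
import torch

# Toy stand-in for one decoder projection: a single matmul over 4 positions.
torch.manual_seed(0)
w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(4, 3)            # 4 token positions, hidden size 3
logits = x @ w                   # shape [4, 3]

# Loss on position 0 only (analogous to keeping 2 of 512 positions).
loss = logits[0].sum()
loss.backward()

# Autograd still executes the backward of the full matmul: the gradient
# w.r.t. w is x.T @ grad_logits, where grad_logits is simply zero at the
# unselected rows. The op-level FLOPs are the same as with a full loss.
# Here each column of w.grad ends up equal to x[0].
print(w.grad)
```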

import torch
import torch.nn as nn
import transformers


class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')

        outputs = model(**inputs)
        logits = outputs.logits

        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)

        # Detach the logits at positions that should not receive gradients.
        # (Note: since only the active positions enter the loss below, this
        # in-place detach is effectively redundant; autograd only flows
        # through tensors that actually contribute to the loss.)
        flat_logits = logits.view(-1, model.config.vocab_size)
        inactive_positions = ~loss_mask.view(-1).bool()
        flat_logits[inactive_positions] = flat_logits[inactive_positions].detach()

        # Keep only the positions selected by the mask.
        active_positions = loss_mask.view(-1).bool()
        active_logits = flat_logits[active_positions]
        active_labels = labels.view(-1)[active_positions]

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(active_logits, active_labels)

        return (loss, outputs) if return_outputs else loss
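One detail worth noting in case others try this (assuming the batches come from a `datasets.Dataset`): by default the Trainer drops dataset columns that don't match the model's `forward` signature, so a custom `loss_mask` column never reaches `compute_loss` unless you keep it via `remove_unused_columns`:

```python
training_args = transformers.TrainingArguments(
    output_dir="out",
    # Keep custom columns such as `loss_mask` in the batch; by default
    # the Trainer removes columns not accepted by model.forward().
    remove_unused_columns=False,
)
```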

Trainer output: