How can I remove backpropagation for specific tokens in the output of a decoder-only transformer?

I have a binary mask that zeroes out the training loss for tokens I don't want to contribute to backpropagation. So far I have only been setting the loss of those tokens to zero, but now I want to skip backpropagation for them entirely in order to speed up training. Does anyone have an idea how to make this modification given that I'm using the CustomTrainer below?

*I’m trying this on the Phi-2 model

import torch
from torch import nn
from torch.utils.data import DataLoader
import transformers


class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Pull the labels and the per-token loss mask out of the batch so they
        # are not passed to the model's forward.
        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')

        # Forward pass
        outputs = model(**inputs)
        logits = outputs.logits

        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)

        # Not used in the loss; kept only for inspection/debugging.
        probs = nn.functional.softmax(logits, dim=-1)
        predicted_token_ids = torch.argmax(probs, dim=-1)

        # Per-token cross-entropy, then zero out the tokens I don't want to train on.
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        losses = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))
        losses = losses.view(-1, inputs['input_ids'].size(1))

        masked_loss = losses * loss_mask
        loss = masked_loss.sum() / (loss_mask.sum() + 1e-9)

        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            'batch_size': self.args.train_batch_size,
            'collate_fn': data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params['shuffle'] = True
            dataloader_params['drop_last'] = self.args.dataloader_drop_last

        return DataLoader(train_dataset, **dataloader_params)

    def get_eval_dataloader(self, eval_dataset=None):
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        data_collator = self.data_collator

        dataloader_params = {
            'batch_size': self.args.eval_batch_size,
            'collate_fn': data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
            'shuffle': False,
            'drop_last': self.args.dataloader_drop_last,
        }

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params.pop('shuffle', None)
            dataloader_params.pop('drop_last', None)

        return DataLoader(eval_dataset, **dataloader_params)

I’ve already tried some modifications, such as detaching the rows of the logits tensor that I don’t want backpropagation to flow through, but I don’t know if this is the right approach.

I need backpropagation to behave like this, where only the first EOS tokens are updated during training:

Autograd acts on whole tensors, so unfortunately there’s no simple way to control which gradients get computed at a finer granularity.

You’d need to rewrite your operations so that you end up with separate tensors that entirely require grad and separate tensors that entirely do not require grad. Does that sound reasonable to do in your case?
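For example, here is a minimal self-contained sketch of that kind of split — `hidden`, `lm_head`, `labels` and `loss_mask` are toy placeholders, not your actual model:

import torch
from torch import nn

# Toy stand-ins: a fake "last hidden state", output projection, labels and
# per-token mask, just to illustrate building a separate tensor that requires
# grad only for the positions you want to train on.
hidden = torch.randn(1, 8, 16, requires_grad=True)   # (batch, seq, hidden)
lm_head = nn.Linear(16, 32)                          # hidden -> vocab
labels = torch.randint(0, 32, (1, 8))
loss_mask = torch.zeros(1, 8, dtype=torch.bool)
loss_mask[0, 3] = True                               # train only this position

# Gather only the active positions into their own tensor. The projection and
# the loss then only ever see this small tensor, and gradients flow back to
# `hidden` only at the selected positions.
active_hidden = hidden.view(-1, 16)[loss_mask.view(-1)]   # (n_active, hidden)
active_labels = labels.view(-1)[loss_mask.view(-1)]       # (n_active,)

active_logits = lm_head(active_hidden)                    # (n_active, vocab)
loss = nn.functional.cross_entropy(active_logits, active_labels)
loss.backward()

The key point is that the small gathered tensor is what goes through the remaining ops and the loss, rather than masking per-token losses after the fact.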


Thanks for the answer. To apply the idea you mentioned, I tried filtering the logits with my mask before passing them to the loss. This reduces the [512, VOCAB_SIZE] logits tensor to [2, VOCAB_SIZE], i.e. only the two tokens on which I want to update the model. However, the FLOPs reported by the Hugging Face Trainer didn’t change at all. Now I don’t know whether the technique is wrong or whether the reported FLOPs simply wouldn’t reflect this, but I’d expect them to, especially since backpropagation costs roughly twice as much computation as the forward pass, and I’m reducing from 512 tokens to just 2, which in my view should make a big difference.

class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')

        outputs = model(**inputs)
        logits = outputs.logits

        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)

        # Attempt to cut gradient flow for the masked-out rows by overwriting
        # them in place with detached copies of themselves.
        inactive_positions = ~loss_mask.view(-1).bool()
        logits.view(-1, model.config.vocab_size)[inactive_positions] = \
            logits.view(-1, model.config.vocab_size)[inactive_positions].detach()

        # Compute the loss only over the positions I actually want to train on.
        active_positions = loss_mask.view(-1).bool()
        active_logits = logits.view(-1, model.config.vocab_size)[active_positions]
        active_labels = labels.view(-1)[active_positions]

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(active_logits, active_labels)

        return (loss, outputs) if return_outputs else loss
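If I understand correctly, the Trainer’s reported FLOPs are estimated from the input size rather than measured, so maybe a better check is to time the forward and backward passes directly. A rough sketch of what I mean (`model`, `batch` and `loss_mask` are placeholders for the real Phi-2 model and a real tokenized batch):

import time
import torch

def time_forward_backward(model, batch, loss_mask):
    # Placeholder helper (not part of the Trainer): times the forward and
    # backward passes separately so the effect of the masking can be measured
    # directly instead of relying on the Trainer's FLOP estimate.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()

    outputs = model(input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'])
    logits = outputs.logits

    active = loss_mask.view(-1).bool()
    active_logits = logits.view(-1, logits.size(-1))[active]
    active_labels = batch['labels'].view(-1)[active]
    loss = torch.nn.functional.cross_entropy(active_logits, active_labels)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    mid = time.perf_counter()

    loss.backward()

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = time.perf_counter()

    # Clear gradients so repeated calls don't accumulate.
    model.zero_grad(set_to_none=True)

    return mid - start, end - mid   # seconds spent in forward, backward

Comparing the backward time with the full 512-token mask against the two-token mask should show whether any compute is actually being saved.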

Trainer output: