How to remove backpropagation for specific tokens from the output of a decoder-only transformer?

I have a binary mask that masks the training loss for tokens that I don’t want to be updated during backpropagation. Until now I have simply set the loss of the tokens I didn’t want to train on to zero. Now, however, I want to remove backpropagation for these tokens entirely, to speed up training. Does anyone have an idea how to make this modification when I’m using this CustomTrainer?

(I’m trying this on the Phi-2 model.)

import torch
import torch.nn as nn
import transformers
from torch.utils.data import DataLoader


class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):

        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')

        # forward pass
        outputs = model(**inputs)
        logits = outputs.logits

        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)

        # diagnostics only, not used in the loss below
        probs = nn.functional.softmax(logits, dim=-1)
        predicted_token_ids = torch.argmax(probs, dim=-1)

        # per-token loss, then zero out the positions excluded by the mask
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        losses = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))
        losses = losses.view(-1, inputs['input_ids'].size(1))

        masked_loss = losses * loss_mask
        loss = masked_loss.sum() / (loss_mask.sum() + 1e-9)

        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            'batch_size': self.args.train_batch_size,
            'collate_fn': data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params['shuffle'] = True
            dataloader_params['drop_last'] = self.args.dataloader_drop_last

        return DataLoader(train_dataset, **dataloader_params)

    def get_eval_dataloader(self, eval_dataset=None):
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        data_collator = self.data_collator

        dataloader_params = {
            'batch_size': self.args.eval_batch_size,
            'collate_fn': data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
            'shuffle': False,
            'drop_last': self.args.dataloader_drop_last,
        }

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params.pop('shuffle', None)
            dataloader_params.pop('drop_last', None)

        return DataLoader(eval_dataset, **dataloader_params)
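
For reference, I call it roughly like this (simplified, with placeholder names for the model and dataset), where every example in the dataset already contains input_ids, labels and a binary loss_mask:

# simplified usage sketch; `model` and `train_dataset` are placeholders
training_args = transformers.TrainingArguments(
    output_dir='out',
    per_device_train_batch_size=4,
    num_train_epochs=1,
)

trainer = CustomTrainer(
    model=model,                      # e.g. the Phi-2 model loaded elsewhere
    args=training_args,
    train_dataset=train_dataset,      # dataset yielding input_ids, labels, loss_mask
    data_collator=transformers.default_data_collator,
)
trainer.train()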

I’ve already tried some modifications, such as detaching the rows of the logits tensor that I don’t want backpropagation to flow through, but I don’t know if this is the right approach.

I need backpropagation to work like this, where only the first EOS tokens are updated during training:

Autograd acts on whole tensors, so unfortunately there’s no simple way to control which gradients get computed at a finer granularity than that.

You’d need to rewrite your operations so that you have separate tensors that entirely require grad and separate tensors that entirely do not require grad. Does that sound reasonable to do in your case?
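
To make that concrete, here is a minimal sketch (just an illustration, nothing specific to your model): requires_grad is a property of a whole tensor, so only rows that live in a separate tensor which doesn’t require grad are actually excluded from the backward pass.

import torch

# two separate tensors: one tracked by autograd, one not
trainable_rows = torch.randn(2, 4, requires_grad=True)   # rows you want to update
frozen_rows = torch.randn(3, 4)                          # rows to exclude (requires_grad=False)

combined = torch.cat([trainable_rows, frozen_rows], dim=0)
loss = combined.pow(2).sum()
loss.backward()

print(trainable_rows.grad is not None)  # True: a gradient was computed for this tensor
print(frozen_rows.grad is None)         # True: autograd never touched this tensor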


Thanks for the answer. To apply the idea you mentioned, I tried filtering the logits with my mask before passing them to the loss. In the end, a tensor of [512, VOCAB_SIZE] was reduced to [2, VOCAB_SIZE], those two rows being the only tokens on which I want to update the model. However, looking at the FLOPs reported by the Hugging Face Trainer, nothing changed. Now I don’t know whether the technique is wrong or whether the FLOPs really wouldn’t show anything in this case, but I believe they should, especially considering that backpropagation costs roughly twice as much computation as the forward pass. Besides, I’m reducing from 512 tokens to just 2, so in my view there should be a big difference.

class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):

        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')

        outputs = model(**inputs)
        logits = outputs.logits

        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)

        # overwrite the rows I don't want to train with detached copies (in place)
        inactive_positions = ~loss_mask.view(-1).bool()
        logits.view(-1, model.config.vocab_size)[inactive_positions] = \
            logits.view(-1, model.config.vocab_size)[inactive_positions].detach()

        # keep only the positions selected by the mask and compute the loss on those
        active_positions = loss_mask.view(-1).bool()
        active_logits = logits.view(-1, model.config.vocab_size)[active_positions]
        active_labels = labels.view(-1)[active_positions]

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(active_logits, active_labels)

        return (loss, outputs) if return_outputs else loss

Trainer output:

Now I don’t know whether the technique is wrong or whether the FLOPs really wouldn’t show anything in this case, but I believe they should, especially considering that backpropagation costs roughly twice as much computation as the forward pass.

FLOPs are more an indication of how compute-bound your operations are, and I’m not sure that would necessarily change, since you’re still doing matmuls, just smaller ones.

Have you checked whether compute-time changed?
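
If you haven’t, a rough way to check (just a sketch; model and batch are placeholders for whatever you’re training with) is to time the forward and backward passes separately around a training step:

import time
import torch

def time_forward_backward(model, batch):
    # rough wall-clock timing of one forward and one backward pass
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    outputs = model(**batch)
    loss = outputs.loss          # or your custom masked loss
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    loss.backward()
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1      # forward seconds, backward seconds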

Yes, the computation time stays the same and sometimes even gets worse, which is what I find strangest. I tried some other ways of manipulating the tensors to cut this part out of backpropagation, like the one I sent previously, but none of them worked. It seems that even after detaching, and reaching the end with a much smaller tensor ([512, 51200] → [2, 51200]), backpropagation still runs the computation for all of them.

I think how much computational saving you can expect really depends on how independent the computation of the single token is from the rest of the tokens.

If, in the very last layer, the last token is a function of the entire output of the previous layer (which it is, through self-attention), it doesn’t matter that you mask out the gradient computation for the other tokens’ losses: you still need to backprop through the entirety of the first n-1 layers.
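
A tiny toy check illustrates this (my own sketch, not your model): even when the loss is computed on a single position, the gradient still reaches the parameters of every earlier layer, because that position’s hidden state depends on all the other positions through attention.

import torch
import torch.nn as nn

torch.manual_seed(0)

embed = nn.Embedding(100, 16)                                     # toy embedding layer
attn = nn.MultiheadAttention(16, num_heads=2, batch_first=True)   # toy attention layer
head = nn.Linear(16, 100)                                         # toy LM head

tokens = torch.randint(0, 100, (1, 8))
x = embed(tokens)
h, _ = attn(x, x, x)            # every position mixes information from the others
logits = head(h)

# loss on a single position only
loss = nn.functional.cross_entropy(logits[0, -1:], tokens[0, -1:])
loss.backward()

# gradients still exist for the embedding and attention weights feeding *all* positions
print(embed.weight.grad.abs().sum() > 0)         # tensor(True)
print(attn.in_proj_weight.grad.abs().sum() > 0)  # tensor(True)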

The only article that I am aware of that does something along these lines is RHO-1 (https://arxiv.org/pdf/2404.07965), but it is not open source, so I don’t know how they implemented it. In principle they get a huge performance gain just by eliminating around 40% of the tokens that are not useful during training. Shouldn’t I be seeing a gain in that sense?

The image below shows part of how they eliminate tokens through the loss, but nothing specific.
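
From what I understood of their description, the selection is something like this sketch I wrote (the paper’s actual scoring is more involved and uses a reference model, so this is only an approximation): compute the per-token loss, keep only the top fraction of tokens, and average the loss over those.

import torch
import torch.nn as nn

def selective_token_loss(logits, labels, keep_ratio=0.6):
    # per-token cross entropy, no reduction
    per_token = nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), reduction='none')
    # keep only the top `keep_ratio` fraction of tokens (here scored by their own loss)
    k = max(1, int(keep_ratio * per_token.numel()))
    kept, _ = torch.topk(per_token, k)
    return kept.mean()

Here the selection only decides which tokens contribute to the loss value; it doesn’t by itself change what the backward pass has to compute.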

achieving the baseline performance 5-10x faster

I think it’s talking about performance not in terms of a per-iteration speedup but rather data efficiency, e.g., you’d reach the same accuracy as the baseline by training on fewer tokens.