Gradient Checkpointing with Transformers BERT model

I’m trying to apply gradient checkpointing to Hugging Face’s Transformers BERT model. I’m not sure I’m doing it right, though! Here is my code snippet, a wrapper around the BERT class:

import torch
from torch import nn
from torch.utils import checkpoint
from transformers import BertModel


class Bert(nn.Module):
    def __init__(self, large, temp_dir, finetune=False):
        super(Bert, self).__init__()
        self.model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=temp_dir)

        self.finetune = finetune  # whether BERT should be fine-tuned or kept frozen

    def custom_bert_forward(self, module):
        def custom_forward(*inputs):
            output = module(inputs[0], attention_mask=inputs[1], token_type_ids=inputs[2])
            return output

        return custom_forward
    
    def forward(self, x, segs, mask):
        if self.finetune:
            
            ## (1) without checkpointing
            top_vec, _ = self.model(x.long(), attention_mask=mask.long(), token_type_ids=segs.long())
            
            ## (2) with checkpointing
            # top_vec = checkpoint.checkpoint(
            #     self.custom_bert_forward(self.model),
            #     x, mask, segs,
            # )

        else:
            self.eval()
            with torch.no_grad():
                top_vec, _ = self.model(x, attention_mask=mask, token_type_ids=segs)
        return top_vec

As I’m checkpointing BERT’s forward function, the memory usage drops significantly (~1/5 of the original), but I’m getting noticeably worse performance compared to non-checkpointing, in terms of the metrics (for my task, which is summarization) that I compute on the validation set. Another observation: with checkpointing, the model also trains considerably faster, which is the opposite of what I have learned about gradient checkpointing. Is there any problem with the implementation, or have I done some part wrong?

Based on the code snippet it seems you are calling the complete model via checkpoint, which should not save any memory, or is this module just part of a larger model?
However, based on the issue description it seems that parts of the model might not be executed at all.

@ptrblck thanks for your response!

Yes, this model is just part of a larger network, i.e., top_vec, the output of this model, is consumed by another model. I see top_vec as a vector that holds the encoded version of x (i.e., src) produced by BERT. In a sense, the weights associated with this class should be updated (i.e., learned) during training. Upon investigation, I noticed that this part of the model consumes much of the memory, so I thought it would be better to checkpoint it. Isn’t it true that wherever we have learning (i.e., updates of model parameters), we can use gradient checkpointing?


I also noticed that there’s a recently implemented option in Huggingface’s BERT which allows us to apply gradient checkpointing easily. It’s an argument specified in BertConfig, and the config object is then passed to BertModel.from_pretrained. I tried that as well, but I have the same issues mentioned above: 1) the performance does not match that of the setting without gradient checkpointing; 2) training is much faster. I would expect it to be slower, at the expense of reduced memory usage, since gradient checkpointing makes the model recompute activations wherever necessary instead of retrieving them, which results in more computation and more training time.
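
For reference, this is roughly what I mean (just a sketch; the exact flag depends on the transformers version, and temp_dir is the same cache directory as in my wrapper above):

from transformers import BertConfig, BertModel

config = BertConfig.from_pretrained(
    'allenai/scibert_scivocab_uncased',
    gradient_checkpointing=True,  # checkpoint each encoder layer's forward pass
)
model = BertModel.from_pretrained(
    'allenai/scibert_scivocab_uncased',
    config=config,
    cache_dir=temp_dir,  # same cache directory as in the Bert wrapper above
)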

I would have the same assumptions and also think that something might not work as expected.
Did you just recently encounter this issue, and was it working before?
Also, which PyTorch version are you using?

@ptrblck I’ve had this issue since the first time I implemented it; it was never working the way I assumed. The code you see runs on torch 1.1.0 (the main BERTSUM implementation), but I also tested it on BART from Huggingface, which uses PyTorch > 1.4.0.

Does the PyTorch version affect checkpointing? I wondered about this before and searched whether gradient checkpointing had only been added in a certain version of PyTorch, but couldn’t find anything useful.

Another observation: what if the returned value has gradient None? I’ve also had this issue: Checkpoint with no grad requiring inputs PROBLEM, which is solved when requires_grad is set to True for the input arguments. I should mention that in HuggingFace’s BertModel there is a src argument that does not need gradients at all, but to silence the warning I have to set its requires_grad to True. Is this also an issue?
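
For completeness, a workaround I’ve seen suggested (just a sketch adapted from my wrapper above, not something in the BERTSUM code) is to pass an extra dummy float tensor that requires grad instead of touching the integer inputs, so that the checkpointed segment actually participates in backward:

import torch
from torch.utils import checkpoint

def custom_bert_forward(module):
    def custom_forward(input_ids, attention_mask, token_type_ids, dummy):
        # dummy is unused; it only exists so that at least one checkpoint input
        # requires grad, which silences the warning and lets backward recompute
        # the segment (and hence produce gradients for BERT's parameters)
        return module(input_ids,
                      attention_mask=attention_mask,
                      token_type_ids=token_type_ids)
    return custom_forward

# inside Bert.forward from the snippet above:
dummy = torch.zeros(1, requires_grad=True)
top_vec, _ = checkpoint.checkpoint(
    custom_bert_forward(self.model),
    x.long(), mask.long(), segs.long(), dummy,
)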

The second issue might be disabling the gradient calculation.
After you’ve set requires_grad=True in the input, are you still seeing a speedup and bad model performance?