self.generate() gives OOM during training

Hi, in part of my code, which modifies Hugging Face’s BartForConditionalGeneration class, I need to call self.generate(**kwargs) for token generation. When I call this function inside the forward method of BartForConditionalGeneration, I get an OOM error on the GPU. However, I can run validation steps through the Trainer class’s evaluate function while training without any issue.

Here’s the truncated snippet I currently have in the forward method of BartForConditionalGeneration:


...
# regular training forward pass through the base model
output = super().forward(
    input_ids=None,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
    encoder_outputs=encoder_outputs,
    encoder_outputs_weighted=encoder_outputs_weighted,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    decoder_inputs_embeds=decoder_inputs_embeds,
    labels=labels,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=True,
    # output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

# this call gives the OOM error on the GPU
# (BaseModelOutput comes from transformers.modeling_outputs)
with torch.no_grad():
    generated_tokens = self.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=encoder_outputs[0]),
        attention_mask=attention_mask,
    )

...

I have tried several tricks to resolve this error, such as lowering batch_size and num_beams (roughly as in the sketch below), but I still get the same error. What are the possible causes of this behaviour?
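
For reference, this is roughly what the reduced-settings attempt looked like; the exact values below are illustrative, not the ones I used:

# illustrative reduced generation settings (values are placeholders)
with torch.no_grad():
    generated_tokens = self.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=encoder_outputs[0]),
        attention_mask=attention_mask,
        num_beams=1,         # greedy search instead of beam search
        max_new_tokens=64,   # cap the number of generated tokens
    )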

Hi, I’m not familiar with Hugging Face, so I’m not sure whether this will solve your problem, but maybe I can give you some ideas.

Could you post the error message? It should tell you the amount of memory already allocated versus how much PyTorch is trying to allocate. Also, do you have some sort of performance monitor to see whether your GPU is actually running out of memory?
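
If you don’t have a monitor at hand, a quick generic PyTorch sketch you can drop around the failing call is (these are standard torch.cuda utilities, nothing Hugging Face specific):

import torch

# snapshot of the CUDA allocator around the suspect call
print(f"allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
print(f"peak:      {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

Alternatively, running watch -n 1 nvidia-smi in a terminal shows live GPU memory usage while the job runs.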

I believe evaluation passes are performed without gradient tracking, whereas regular forward passes are performed with it. If you don’t need to train at that point, tracking gradients unnecessarily increases GPU memory usage. Could you make sure you are using things like

model.eval()
with torch.no_grad():
or
requires_grad=False when instantiating tensors (see the sketch below for how these fit together)?
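
A minimal sketch of how those pieces fit together (model and batch are placeholders for your model instance and your input tensors):

model.eval()              # disable dropout and other training-only behaviour
with torch.no_grad():     # don't record operations for autograd
    generated = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )
model.train()             # switch back before resuming training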

I’m not familiar with this specific model, but in many transformers the layer that generates token embeddings is itself trainable, so its parameters contribute to the autograd graph.
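
One way to check, with plain PyTorch (model again being a placeholder), is to list which embedding-related parameters still track gradients; freeze them only if you are sure you don’t need to train them:

# list embedding parameters that still require gradients
for name, p in model.named_parameters():
    if p.requires_grad and "embed" in name:
        print(name, tuple(p.shape))
        # p.requires_grad = False  # uncomment to freeze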