Hi, in a subclass of Hugging Face's BartForConditionalGeneration, I need to call self.generate(**kwargs) to produce tokens. When I call this function inside the forward method of the class, I get an OOM error on the GPU. However, running validation through the Trainer's evaluate method during training works fine.
Here's the truncated snippet I currently have in the forward method:
```python
...
output = super().forward(
    input_ids=None,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
    encoder_outputs=encoder_outputs,
    encoder_outputs_weighted=encoder_outputs_weighted,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    decoder_inputs_embeds=decoder_inputs_embeds,
    labels=labels,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=True,
    # output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

# this line raises the OOM error on the GPU
with torch.no_grad():
    generated_tokens = self.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=encoder_outputs[0]),
        attention_mask=attention_mask,
    )
...
```
I have tried a bunch of tricks to resolve this error, such as lowering batch_size and num_beams, but I still get the same OOM. What are the possible causes of this behaviour?
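One thing I suspect (a toy, self-contained sketch in plain PyTorch, not my actual model): the encoder hidden states produced during training are still attached to the autograd graph, so even though generate runs under torch.no_grad(), passing them in keeps all of the training forward's activations reachable. Detaching the tensor first would drop that graph:

```python
import torch

# Toy stand-in for an encoder forward pass during training:
# the output tensor is attached to the autograd graph.
enc = torch.nn.Linear(8, 8)
hidden = enc(torch.randn(2, 4, 8))

# `hidden` carries a grad_fn, so every intermediate activation behind it
# stays in memory as long as `hidden` is alive.
assert hidden.requires_grad
assert hidden.grad_fn is not None

# Detaching returns a view that shares storage but is cut off from the
# graph; once the original references go away, the graph can be freed.
hidden_detached = hidden.detach()
assert not hidden_detached.requires_grad
assert hidden_detached.grad_fn is None
```

If this is the cause, passing `encoder_outputs[0].detach()` into `BaseModelOutput` (hypothetical fix, untested on my setup) might reduce the peak memory during generation. I'd appreciate confirmation of whether this reasoning holds.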