How to enable Flash Attention in a BERT model pre-trained on a custom dataset?

I have a BERT model fully pre-trained on my custom dataset. Its current implementation does not use flash attention, and I would like to enable it.

I have two specific questions:

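  1. If I modify the code to simply replace the existing attention computation with torch.nn.functional.scaled_dot_product_attention (PyTorch's fused attention function, which can dispatch to a flash attention kernel), will that enable flash attention correctly?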
  2. After making this modification, do I need to fine-tune the model, or can I use the existing weights as-is without further training?
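A minimal sketch of the kind of swap question 1 describes, assuming a standard BERT-style multi-head self-attention where the query, key and value tensors are already split into heads with shape (batch, heads, seq_len, head_dim). The function names `manual_attention` and `sdpa_attention` are hypothetical and not taken from the original model code. The sketch compares the explicit softmax(QK^T / sqrt(d)) V computation against torch.nn.functional.scaled_dot_product_attention on random inputs with an additive padding mask:

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q, k, v, attn_mask=None):
    # Hypothetical stand-in for a classic BERT attention block:
    # softmax(Q K^T / sqrt(head_dim)) V, no dropout (inference mode).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        # Additive mask: 0 for real tokens, -inf for padding positions.
        scores = scores + attn_mask
    probs = torch.softmax(scores, dim=-1)
    return probs @ v


def sdpa_attention(q, k, v, attn_mask=None):
    # Same math through PyTorch's fused entry point. The default scale is
    # 1/sqrt(head_dim) and dropout_p defaults to 0.0, matching the code above.
    # PyTorch chooses the backend (flash, memory-efficient or math) itself.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)


torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 16, 32
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Padding mask broadcast over heads and query positions:
# the last 4 tokens of the second sequence are padding.
mask = torch.zeros(batch, 1, 1, seq_len)
mask[1, :, :, -4:] = float("-inf")

out_manual = manual_attention(q, k, v, mask)
out_sdpa = sdpa_attention(q, k, v, mask)

# Both paths should agree up to float32 rounding.
print(torch.allclose(out_manual, out_sdpa, atol=1e-5))
```

If the two paths agree like this, the swap does not change the function the model computes, so the existing weights are used unchanged; small numerical differences are expected, especially in fp16/bf16. If the model is trained further, BERT's attention-probability dropout would have to be passed through the `dropout_p` argument during training. Note that this CPU/float32 run does not exercise the CUDA flash kernel: PyTorch only selects the flash backend on supported GPUs with fp16/bf16 inputs, and may fall back to the memory-efficient or math backend when an arbitrary `attn_mask` (such as a padding mask) is passed. Wrapping the call in `torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)` (PyTorch 2.3+) restricts it to the flash backend, so it raises an error instead of silently falling back.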

I would greatly appreciate any clarification or guidance on this! Thanks in advance.