I have a model that has been fully pre-trained on my custom dataset. Its current implementation does not use flash attention, so the model runs without it. I want to enable flash attention in the model.
I have two specific questions:
- If I modify the code to simply replace the existing attention computation with `torch.nn.functional.scaled_dot_product_attention` (which can dispatch to the flash attention implementation), will that enable flash attention correctly? (A rough sketch of the change I have in mind is below.)
- After making this modification, do I need to finetune the model, or can I use it as-is without further training?
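For reference, this is roughly the change I'm considering. The module layout here is a simplified sketch with placeholder names, not my actual code:

```python
import torch
import torch.nn.functional as F

# Hypothetical attention module; the structure and names are illustrative only.
class Attention(torch.nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape (B, T, C) -> (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Current manual path (what the model does today), kept for comparison:
        # attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # attn = attn.softmax(dim=-1)
        # out = attn @ v

        # Proposed change: let PyTorch pick the fastest available backend,
        # which can be flash attention when the hardware and dtypes allow it.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        out = out.transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
```

If it matters, I could also wrap the call in the SDPA backend context manager available in recent PyTorch versions to restrict dispatch to the flash attention backend and confirm it is actually being used, but the sketch above is the core of the change.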
I would greatly appreciate any clarification or guidance on this! Thanks in advance.