Customized efficient implementation of scaled_dot_product_attention

I would like to implement a customized function similar to scaled_dot_product_attention, as shown in this doc. Specifically, I want to replace the softmax with a modified function that has a learned parameter w, while keeping the efficient low-level CUDA implementation that the default function can use. I noticed this discussion about the CUDA function code.
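For reference, the default operation and the intended change can be written as follows. Here $f_w$ is only a placeholder for the custom row-wise function with learned parameter $w$; its exact form is not specified in the question. $d_k$ is the query/key embedding dimension, which gives the default scale.

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
\quad\longrightarrow\quad
f_w\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
$$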

What would be the best practice to implement this? I know how to implement it in plain PyTorch, but I am afraid it might not be as efficient during training as the default efficient implementation.
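A plain-PyTorch reference version is still useful as a correctness baseline for any faster kernel. The sketch below assumes one hypothetical choice of "modified softmax" (not from the thread): each row gets an extra learned logit w, so the weights are exp(s_ij) / (exp(w) + Σ_k exp(s_ik)). The class name `ExtraLogitAttention` is illustrative. It also shows where the efficiency concern comes from: the full score matrix is materialized in memory.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ExtraLogitAttention(nn.Module):
    """Hypothetical attention with a learned extra logit w in the softmax.

    weights_ij = exp(s_ij) / (exp(w) + sum_k exp(s_ik)), with s = q k^T / sqrt(d).
    """

    def __init__(self):
        super().__init__()
        # Learned scalar w (an illustrative choice of "modified softmax").
        self.w = nn.Parameter(torch.zeros(()))

    def forward(self, q, k, v):
        # q: (..., L, d), k: (..., S, d), v: (..., S, dv)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        # Append w as one extra column so the built-in (stable) softmax can be reused.
        extra = self.w.expand(*scores.shape[:-1], 1)
        probs = torch.softmax(torch.cat([scores, extra], dim=-1), dim=-1)
        # Drop the extra column; each row of weights now sums to less than 1.
        # The full (L, S) score matrix is materialized here, which is the
        # memory traffic that fused attention kernels are designed to avoid.
        return probs[..., :-1] @ v


torch.manual_seed(0)
attn = ExtraLogitAttention()
q, k, v = torch.randn(2, 4, 8, 16), torch.randn(2, 4, 10, 16), torch.randn(2, 4, 10, 16)
out = attn(q, k, v)
out.sum().backward()  # gradients flow into attn.w

# Sanity check: as w -> -inf the extra slot vanishes and this reduces to standard SDPA.
with torch.no_grad():
    attn.w.fill_(-1e4)
    assert torch.allclose(attn(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-5)
```

Autograd handles the backward pass here automatically. The cost is memory and bandwidth, not correctness.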

Or is there a recommended doc on writing custom CUDA functions for PyTorch? Maybe this one?

Thanks!

Yes, C++/CUDA extensions are a supported approach for writing custom implementations and exposing them to PyTorch as functions or layers.
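The usual glue between hand-written kernels and autograd is a `torch.autograd.Function` with an explicit `forward` and `backward`. The sketch below assumes the same kind of hypothetical softmax variant as an illustration (not from the thread): each row gets one extra learned logit w, so the weights are exp(s_ij) / (exp(w) + Σ_k exp(s_ik)). Both passes use ordinary tensor ops so they can be verified with `torch.autograd.gradcheck`. In a C++/CUDA extension, the bodies of these two methods are what compiled kernels would replace. The name `ExtraLogitAttentionFn` is hypothetical.

```python
import math

import torch


class ExtraLogitAttentionFn(torch.autograd.Function):
    """Explicit forward/backward for a hypothetical 'extra learned logit' attention."""

    @staticmethod
    def forward(ctx, q, k, v, w):
        scale = 1.0 / math.sqrt(q.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale
        # Append the scalar logit w as an extra column, then normalize row-wise.
        extra = w.expand(*scores.shape[:-1], 1)
        p_ext = torch.softmax(torch.cat([scores, extra], dim=-1), dim=-1)
        ctx.save_for_backward(q, k, v, w, p_ext)
        ctx.scale = scale
        return p_ext[..., :-1] @ v

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, w, p_ext = ctx.saved_tensors
        p = p_ext[..., :-1]
        grad_v = p.transpose(-2, -1) @ grad_out
        grad_p = grad_out @ v.transpose(-2, -1)
        # The extra column does not feed the output, so its upstream gradient is 0.
        grad_p_ext = torch.cat([grad_p, torch.zeros_like(p_ext[..., -1:])], dim=-1)
        # Softmax backward: dS = P * (dP - rowsum(P * dP)).
        grad_s_ext = p_ext * (grad_p_ext - (p_ext * grad_p_ext).sum(-1, keepdim=True))
        # Chain rule through scores = scale * q @ k^T.
        grad_s = grad_s_ext[..., :-1] * ctx.scale
        grad_q = grad_s @ k
        grad_k = grad_s.transpose(-2, -1) @ q
        # w is broadcast to every row, so its gradient sums over all rows.
        grad_w = grad_s_ext[..., -1].sum().reshape(w.shape)
        return grad_q, grad_k, grad_v, grad_w


# Numerically check the hand-written backward (double precision for gradcheck).
torch.manual_seed(0)
q = torch.randn(1, 2, 3, 4, dtype=torch.double, requires_grad=True)
k = torch.randn(1, 2, 5, 4, dtype=torch.double, requires_grad=True)
v = torch.randn(1, 2, 5, 4, dtype=torch.double, requires_grad=True)
w = torch.tensor(0.3, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(ExtraLogitAttentionFn.apply, (q, k, v, w))
```

Once an extension provides compiled forward and backward kernels, they would be called from these two methods, and calling code that uses `ExtraLogitAttentionFn.apply` would not need to change.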


If I want to customize nn.functional.scaled_dot_product_attention, is implementing the three functions at::_scaled_dot_product_flash_attention, at::_scaled_dot_product_efficient_attention, and at::_scaled_dot_product_attention_math enough, or do I also need to implement other I/O functions? Also, where are at::_scaled_dot_product_flash_attention and at::_scaled_dot_product_efficient_attention defined?