I would like to implement a custom function similar to scaled_dot_product_attention, as shown in this doc. Specifically, I want to replace the softmax with a modified function that has a learned parameter w, while keeping the efficiency of the low-level CUDA implementation. I found this discussion about the CUDA kernel code.
What is the best practice for implementing this? I know how to write it in plain PyTorch, but I am afraid that would be less efficient during training than the default fused implementation.
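For reference, here is a minimal sketch of the plain-PyTorch version I have in mind. The class name `CustomAttention` and the method `modified_softmax` are just illustrative, and using w as a temperature inside the softmax is only a placeholder assumption standing in for the actual modified function.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomAttention(nn.Module):
    """Scaled dot-product attention with a placeholder replacement for softmax."""

    def __init__(self):
        super().__init__()
        # Hypothetical learned parameter w; here it acts as a temperature.
        self.w = nn.Parameter(torch.ones(1))

    def modified_softmax(self, scores):
        # Placeholder: the real modified function would go here.
        return F.softmax(scores * self.w, dim=-1)

    def forward(self, q, k, v):
        # q, k, v: (batch, heads, seq_len, head_dim)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        weights = self.modified_softmax(scores)
        return weights @ v
```

This version materializes the full (seq_len x seq_len) attention matrix, which is the cost I would like to avoid.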
Or is there a recommended doc on writing custom CUDA functions for PyTorch? Maybe this one?
Thanks!