I essentially want to replace the product operation inside matrix multiplication with a different operation. I have very little experience writing custom PyTorch kernels, so I would like to reuse as much of the machinery behind torch.matmul as possible while changing only the underlying element-wise operation. I have been looking everywhere: is there an easy way to do this before I hack together a terribly optimized CUDA kernel?
The source code for PyTorch matrix multiplication is pointed to in this topic, but I am not really sure how to make use of it, and again, if there is an easier way, that would be great.
The aim is to have an efficient operation that can run on the GPU (CUDA).
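One way to get this on the GPU without writing any kernel is to express the generalized matmul with broadcasting: expand the two matrices to a shared (m, k, n) shape, apply the replacement for the element-wise product, then reduce over the shared k dimension. This is a sketch of that idea (the function name `generalized_matmul` and its parameters are my own, not a PyTorch API); note the O(m·k·n) intermediate tensor means it only scales to moderate sizes:

```python
import torch

def generalized_matmul(A, B, combine=torch.mul,
                       reduce_fn=lambda t, dim: torch.sum(t, dim=dim)):
    # A: (m, k), B: (k, n).
    # Broadcast to (m, k, n): element [i, k, j] is combine(A[i, k], B[k, j]).
    expanded = combine(A.unsqueeze(2), B.unsqueeze(0))
    # Reduce over the shared k dimension -> (m, n).
    return reduce_fn(expanded, dim=1)

# With the defaults (multiply, then sum) this reproduces ordinary matmul.
# Example of a swapped-in operation: tropical (min-plus) "matmul",
#   C[i, j] = min over k of (A[i, k] + B[k, j])
A = torch.randn(4, 5)
B = torch.randn(5, 3)
C = generalized_matmul(A, B, combine=torch.add,
                       reduce_fn=lambda t, dim: t.min(dim=dim).values)
```

Every operation here is an ordinary PyTorch op, so moving the inputs to CUDA (`A.cuda()`, `B.cuda()`) runs the whole thing on the GPU, just with far more memory traffic than a fused kernel would need.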
That explains the pain in finding the source code. The Triton tutorial looks very promising, thank you. Overall, the rabbit hole I have fallen into seems to be a very rare use case, and is going to take a lot more work than it is worth right now. I may focus on a CPU implementation for now instead.
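For a CPU-first version, the same broadcasting trick works in plain NumPy, which makes a quick baseline to validate against before investing in a Triton kernel. A minimal sketch, using min-plus as a stand-in for whatever operation replaces the product (the function name is hypothetical):

```python
import numpy as np

def minplus_matmul(A, B):
    # C[i, j] = min over k of (A[i, k] + B[k, j])
    # A[:, :, None] has shape (m, k, 1), B[None, :, :] has shape (1, k, n);
    # adding them broadcasts to (m, k, n), then we reduce over axis 1 (the k axis).
    return np.min(A[:, :, None] + B[None, :, :], axis=1)
```

A correct-but-slow reference like this is also useful later as the ground truth when testing an optimized kernel.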