I want to use custom Triton kernels with PyTorch for training. How can I do this with minimal code changes? What I mean is, I don't want to define a custom class and call apply for each particular kernel. Is there a different way to do this?
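For context, here's a minimal sketch of the pattern I'm trying to avoid (the launchers `my_triton_fwd` / `my_triton_bwd` are hypothetical stand-ins for real Triton kernel launches):

```python
import torch

# Hypothetical stand-ins for the host-side Triton launchers; in practice
# these would launch kernels via kernel[grid](...).
def my_triton_fwd(x):
    return x * 2  # placeholder computation

def my_triton_bwd(x, grad_out):
    return grad_out * 2  # placeholder gradient

# The boilerplate I want to skip: one autograd.Function subclass per
# kernel, invoked through .apply().
class MyKernelFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return my_triton_fwd(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return my_triton_bwd(x, grad_out)

x = torch.randn(8, requires_grad=True)
y = MyKernelFn.apply(x)  # the per-kernel .apply() call I'd like to avoid
y.sum().backward()
```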
Thanks, but I was wondering if there's another way of doing this without using torch.autograd.Function, basically the class/function you'd mentioned here.
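To be concrete about what I'm hoping exists, something along these lines — a sketch assuming `torch.library.custom_op` (available in recent PyTorch versions, 2.4+ if I'm not mistaken) and the same hypothetical launchers as before:

```python
import torch

# Same hypothetical stand-ins for the real Triton launchers.
def my_triton_fwd(x):
    return x * 2  # placeholder computation

def my_triton_bwd(x, grad_out):
    return grad_out * 2  # placeholder gradient

# Register the kernel as a custom op instead of subclassing
# torch.autograd.Function; "mylib::my_op" is a made-up op name.
@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: torch.Tensor) -> torch.Tensor:
    return my_triton_fwd(x)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad_out):
    (x,) = ctx.saved_tensors
    return my_triton_bwd(x, grad_out)

# Autograd attached via registration, no Function subclass or .apply().
my_op.register_autograd(backward, setup_context=setup_context)

x = torch.randn(8, requires_grad=True)
y = my_op(x)  # called like a plain function
y.sum().backward()
```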