I wrote a CUDA kernel myself, and I can use AT_DISPATCH_FLOATING_TYPES_AND_HALF to make it support double/float32/float16, but that macro does not cover bfloat16. How can I make the kernel work with the whole series of double/float32/float16/bfloat16? A rough sketch of what I currently have is below.
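This is only a minimal sketch of my current dispatch pattern, not my real op: `my_scale_kernel` and `my_scale_cuda` are placeholder names for a hypothetical elementwise kernel, and the actual shapes/launch config in my code are different.

```cpp
#include <torch/extension.h>

// Hypothetical elementwise kernel: out = in * alpha, just to show the
// templated scalar_t pattern that the dispatch macro instantiates.
template <typename scalar_t>
__global__ void my_scale_kernel(const scalar_t* __restrict__ in,
                                scalar_t* __restrict__ out,
                                scalar_t alpha,
                                int64_t n) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * alpha;
  }
}

torch::Tensor my_scale_cuda(const torch::Tensor& input, double alpha) {
  auto output = torch::empty_like(input);
  const int64_t n = input.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);

  // This covers double / float32 / float16, but I would also like
  // bfloat16 (at::kBFloat16) to be dispatched here.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "my_scale_cuda", [&] {
        my_scale_kernel<scalar_t><<<blocks, threads>>>(
            input.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            static_cast<scalar_t>(alpha),
            n);
      });
  return output;
}
```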
I tried this, and my op now supports double/float/fp16/bf16. However, I found that double/float/fp16 give the same training results as PyTorch autograd, while bf16 mode shows a big gap between PyTorch autograd and my CUDA kernel. Besides, bf16 runs much slower than the other three. Do you know why bf16 behaves differently from the other three?