Hi,
I am trying to implement a CUDA kernel as a PyTorch extension. To be specific, I want the kernel to work with double/float/fp16/bf16, and I use __shfl_down_sync in the kernel. Part of my code is like this:
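(My real code is longer, so this is a simplified sketch: warp_reduce_sum and my_kernel are placeholder names, and the dispatch uses AT_DISPATCH_FLOATING_TYPES_AND2 to cover all four dtypes.)

```cpp
#include <torch/extension.h>

// Placeholder warp-level sum reduction across the 32 lanes of a warp.
template <typename scalar_t>
__device__ scalar_t warp_reduce_sum(scalar_t val) {
  for (int offset = 16; offset > 0; offset /= 2) {
    // This is the call the compiler complains about for fp16/bf16.
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

template <typename scalar_t>
__global__ void my_kernel(const scalar_t* __restrict__ in,
                          scalar_t* __restrict__ out,
                          int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  scalar_t v = (i < n) ? in[i] : scalar_t(0);
  v = warp_reduce_sum(v);
  if (threadIdx.x % 32 == 0 && i < n) {
    out[i / 32] = v;  // one partial sum per warp
  }
}

torch::Tensor my_op(torch::Tensor in) {
  const int64_t n = in.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  auto out = torch::empty({(n + 31) / 32}, in.options());
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      in.scalar_type(), "my_op", [&] {
        my_kernel<scalar_t><<<blocks, threads>>>(
            in.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), n);
      });
  return out;
}
```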
When I compile, I get an error: it seems that there are two definitions of __shfl_down_sync, one that works with __half and one that works with float.
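If it helps, my guess is that the ambiguity comes from c10::Half (what scalar_t becomes for fp16 inside the dispatch) being implicitly convertible to both float and, in device code, __half, so the compiler cannot pick one overload. A minimal illustration of what I think is happening (repro is just a made-up name):

```cpp
#include <cuda_fp16.h>
#include <c10/util/Half.h>

__device__ float repro(c10::Half h) {
  // c10::Half defines operator float() and, under nvcc, operator __half(),
  // so this call matches both the float and the __half overloads equally well.
  return __shfl_down_sync(0xffffffff, h, 1);  // error: ambiguous call
}
```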
Do you know how to make this work?