Half type dependency problem

Hi,

I am trying to implement a CUDA kernel as a PyTorch extension.

Part of my code looks like this:

Specifically, I want the kernel to work with double, float, fp16, and bf16, and the kernel calls __shfl_down_sync.

When I compile, I get the following error:

It seems that two overloads of __shfl_down_sync both match the call, one taking __half and the other taking float.
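One possible workaround is sketched below. It assumes the ambiguity arises because the kernel's scalar type (for example a 16-bit wrapper type) converts implicitly to both __half and float, so the compiler cannot choose an overload. The helper name warp_sum, the demo kernel, and run_warp_sum_demo are hypothetical and not taken from the original code. The sketch casts to float explicitly, so only the float overload of __shfl_down_sync is ever considered, and it keeps a separate overload for double to avoid losing precision. It also assumes the type supports static_cast to and from float; plain __half does so only when __CUDA_NO_HALF_CONVERSIONS__ is not defined.

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Hypothetical helper: warp-level sum that accumulates in float.
// The explicit cast means only the float overload of __shfl_down_sync
// is a candidate, so a type convertible to both __half and float
// no longer triggers an ambiguous call.
template <typename T>
__device__ __forceinline__ T warp_sum(T v) {
  float acc = static_cast<float>(v);
  for (int offset = 16; offset > 0; offset >>= 1) {
    acc += __shfl_down_sync(0xffffffffu, acc, offset);
  }
  return static_cast<T>(acc);  // only lane 0 holds the full sum
}

// Non-template overload so double is shuffled at full precision.
__device__ __forceinline__ double warp_sum(double v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(0xffffffffu, v, offset);
  }
  return v;
}

// Demo kernel: each lane contributes its lane index as a __half value.
__global__ void demo_kernel(float* out) {
  __half v = __float2half(static_cast<float>(threadIdx.x));
  __half s = warp_sum(v);
  if (threadIdx.x == 0) {
    *out = __half2float(s);
  }
}

// Host-side driver: launches one warp and returns lane 0's result.
float run_warp_sum_demo() {
  float* d_out = nullptr;
  float h_out = 0.0f;
  cudaMalloc(&d_out, sizeof(float));
  demo_kernel<<<1, 32>>>(d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_out);
  return h_out;
}
```

Accumulating in float is a design choice rather than a requirement: it sidesteps the overload question and also reduces rounding error for 16-bit inputs, at the cost of a conversion on entry and exit.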

Do you know how to make this work?