Which type of CUDA Math API does torch.sin() use?

CUDA offers two kinds of device math functions for operations such as sin, cos, sqrt, and exp.
One is the accurate version, e.g.

__device__ double sin(double x)

which can be found in
https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__DOUBLE.html#group__CUDA__MATH__DOUBLE_1g3ebbca20a2937d1fe51329402880df85
The other is the fast approximate version, e.g.

__device__ float __sinf(float x)

which can be found in
https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__INTRINSIC__SINGLE.html#group__CUDA__MATH__INTRINSIC__SINGLE_1gfa0ea4b2cee94521792ead0deb03addb

My question is: if I call torch.sin(), which kind of CUDA function does it use?
If it uses the accurate version, how can I use the fast approximate version instead?

The native implementation should use std::sin as seen here.
To use the fast approximate implementation you could use a custom CUDA extension.


Thanks for your fast reply! I solved the problem by following https://pytorch.org/tutorials/advanced/cpp_extension.html.

BTW, I found that my implementation of the fast version is slower than calling torch.xxx directly. I know this may be more of a CUDA question, but I'd appreciate any help.

For example, here is my fast version of exp().
In the .cu file, my implementation is:

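#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

template <typename scalar_t>
__global__ void fastexp_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  // The last block may be only partly filled, so skip out-of-range threads.
  if (index < input.size(0)) {
    // expf works in single precision (double inputs are rounded to float);
    // with --use_fast_math it compiles to the __expf intrinsic.
    output[index] = expf(input[index]);
  }
}

torch::Tensor fastexp_cuda(torch::Tensor input) {
  const auto size = input.size(0);
  auto output = torch::empty_like(input);
  if (size == 0) return output;  // a launch with zero blocks is an error
  const int threads = 1024;
  const dim3 blocks((size + threads - 1) / threads);
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "fastexp_cuda", ([&] {
    // Launch on PyTorch's current stream rather than the default stream.
    fastexp_cuda_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>());
  }));
  return output;
}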

I compile it with the JIT method, and it does get faster when I pass the --use_fast_math flag.
But it is still slower than torch.exp().
Does torch.exp() use some optimization, or is there something redundant in my implementation?