CUDA has two kinds of math APIs for computing transcendental functions such as sin, cos, sqrt, exp, etc.
One is the accurate version, e.g.
__device__ double sin ( double x )
which can be found in
https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__DOUBLE.html#group__CUDA__MATH__DOUBLE_1g3ebbca20a2937d1fe51329402880df85
The other is the fast approximate version, e.g.
__device__ float __sinf ( float x )
which can be found in
https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__INTRINSIC__SINGLE.html#group__CUDA__MATH__INTRINSIC__SINGLE_1gfa0ea4b2cee94521792ead0deb03addb
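For example, inside a kernel the two flavors would be called like this (just to make the contrast concrete; demo is a made-up kernel name):

__global__ void demo(const double* xd, double* yd, const float* xf, float* yf) {
    yd[0] = sin(xd[0]);     // accurate double-precision version
    yf[0] = __sinf(xf[0]);  // fast approximate single-precision intrinsic
}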
My question is: if I call torch.sin(), which type of CUDA API does it use?
If it uses the accurate version, how can I use the fast approximate version?
The native implementation should use std::sin, as seen here.
To use the fast approximate implementation you could use a custom CUDA extension.
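For example, a kernel inside such an extension could call __sinf directly. A rough sketch, assuming a 1-D, contiguous float32 CUDA tensor (the fast_sin names are just placeholders):

#include <torch/extension.h>

__global__ void fast_sin_kernel(const float* __restrict__ in,
                                float* __restrict__ out,
                                int64_t n) {
    const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = __sinf(in[i]);  // fast approximate single-precision sine
    }
}

torch::Tensor fast_sin(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda() && input.scalar_type() == torch::kFloat,
                "fast_sin expects a float32 CUDA tensor");
    auto x = input.contiguous();
    auto out = torch::empty_like(x);
    const int threads = 256;
    const int blocks = (x.numel() + threads - 1) / threads;
    fast_sin_kernel<<<blocks, threads>>>(
        x.data_ptr<float>(), out.data_ptr<float>(), x.numel());
    return out;
}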
Thanks for your fast reply! I solved the problem by following https://pytorch.org/tutorials/advanced/cpp_extension.html.
BTW, I found that my implementation of the fast version is slower than directly calling torch.xxx. I know this may be more of a CUDA problem, but I would really appreciate your help.
For example, when I implement the fast version of exp(), my implementation in the CUDA file is:
#include <torch/extension.h>

template <typename scalar_t>
__global__ void fastexp_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<scalar_t, 1, torch::RestrictPtrTraits> output) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  // Guard against out-of-bounds threads in the last (partial) block.
  if (index < input.size(0)) {
    // expf is the single-precision call; nvcc's --use_fast_math lowers it to __expf.
    output[index] = expf(input[index]);
  }
}

torch::Tensor fastexp_cuda(torch::Tensor input) {
  const auto size = input.size(0);
  auto output = torch::empty_like(input);
  const int threads = 1024;
  const dim3 blocks((size + threads - 1) / threads);
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "fastexp_cuda", ([&] {
    fastexp_cuda_kernel<scalar_t><<<blocks, threads>>>(
        input.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>(),
        output.packed_accessor32<scalar_t, 1, torch::RestrictPtrTraits>());
  }));
  return output;
}
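The C++ binding side just follows the pattern from the tutorial, roughly like this (a sketch with input checks trimmed; fastexp is the wrapper name):

#include <torch/extension.h>

// Declared in the .cu file above.
torch::Tensor fastexp_cuda(torch::Tensor input);

torch::Tensor fastexp(torch::Tensor input) {
    TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
    TORCH_CHECK(input.dim() == 1, "input must be 1-D");
    return fastexp_cuda(input);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fastexp", &fastexp, "fast exp (CUDA)");
}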
I use the JIT compilation method, and I do find it is faster when I compile with the --use_fast_math flag.
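(For completeness: instead of relying on the flag, the kernel could call the __expf intrinsic explicitly; a minimal, untested, float32-only sketch that would drop into the same .cu file:)

__global__ void fastexp_intrinsic_kernel(
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> input,
    torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> output) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < input.size(0)) {
    // __expf is the fast approximate intrinsic, independent of --use_fast_math.
    output[index] = __expf(input[index]);
  }
}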
But it is still slower than torch.exp().
Does torch.exp() use any optimizations? Or does my implementation have anything redundant?