Compute sin and cos simultaneously on the same input

A pattern found in the wild, in DiffusionInst/head.py (chenhaoxing/DiffusionInst on GitHub):
`embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)`

This pattern often appears in positional-encoding computations, and possibly in random Fourier features.

Is there a nice fused fast way of doing this in PyTorch?

This is equivalent to taking the complex exponential, exp(ix) = cos(x) + i·sin(x). A separate alias such as torch.sincos(x) could be a feature request. At least in NumPy, however, the complex-exponential route is a bit slower than computing sin and cos separately, even though SSE recipes such as sincos_ps exist. See the Stack Overflow question "numpy - Is there a fast Way to return Sin and Cos of the same value in Python?" and http://gruntthepeon.free.fr/ssemath/sse_mathfun.h.

Ideally, this pattern could also be fused in the future.

That’s an interesting use case and yes, it should be possible to write a fused version of this operation.
One approach would be to use the nvFuser Python frontend, which is available in the latest nightly releases.
Here is a complete code snippet showing the fused kernel. Its one limitation is the torch.cat operation: cat is not yet supported by nvFuser (still work in progress), so the concatenation still has to run as an additional eager kernel.

import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch._C._nvfuser import Fusion, FusionCache, FusionDefinition, DataType
from typing import List


def sincos_eager(embeddings):
    """Return ``sin(embeddings)`` and ``cos(embeddings)`` joined on the last dim.

    Unfused eager reference implementation: sin and cos are computed as
    separate ops and then concatenated, so the last dimension of the result
    is twice that of the input (sin values first, then cos values).
    """
    sin_part = torch.sin(embeddings)
    cos_part = torch.cos(embeddings)
    return torch.cat([sin_part, cos_part], dim=-1)


def nvFuser_sincos(fd: FusionDefinition):
    """Record a fusion computing sin and cos of one 2-D tensor into ``fd``.

    The input is declared with two symbolic (dynamic) sizes, both dimensions
    contiguous, and dtype ``DataType.Float``. The fusion exposes two separate
    outputs, in order: ``sin(x)`` and then ``cos(x)``. Concatenating them is
    left to the caller.

    Args:
        fd: The ``FusionDefinition`` to populate with inputs, ops and outputs.
    """
    embeddings = fd.define_tensor(
        symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Float
    )
    output1 = fd.ops.sin(embeddings)
    output2 = fd.ops.cos(embeddings)
    # TODO: fuse the concatenation as well once nvFuser supports cat (still
    # WIP). Hypothetically something like
    # ``fd.ops.cat((output1, output2), dim=-1)``; until then the caller must
    # run ``torch.cat`` on the two outputs as an extra eager kernel.
    fd.add_output(output1)
    fd.add_output(output2)


# Create the Fusion object once and record the sin/cos definition into it.
sincos = Fusion()
with FusionDefinition(sincos) as fusion_def:
    nvFuser_sincos(fusion_def)


def python_frontend_sincos(fs: Fusion, input: torch.Tensor) -> List[torch.Tensor]:
    """Execute the prebuilt fusion ``fs`` with ``input`` as its only input.

    Args:
        fs: A ``Fusion`` whose definition has already been recorded.
        input: Tensor passed as the fusion's sole input.

    Returns:
        Whatever ``fs.execute`` returns. For the sin/cos fusion in this script,
        that is the two outputs in definition order: sin first, then cos.
    """
    return fs.execute([input])

# setup: a random (1024, 1024) input tensor on the CUDA device
embeddings = torch.randn(1024, 1024, device='cuda')

# compare: the nvFuser path still needs an eager torch.cat on its two outputs
# to match the layout produced by the eager reference.
out_pynvf = torch.cat(python_frontend_sincos(sincos, embeddings), -1)
out_eager = sincos_eager(embeddings)

max_abs_err = (out_pynvf - out_eager).abs().max()
print(max_abs_err)
# tensor(0., device='cuda:0')

# profile
# Profile a single call of the unfused eager path. Only CUDA activity is
# requested, and the table is sorted by total CPU time, so the host-side
# cudaLaunchKernel / cudaDeviceSynchronize rows come first in the output.
# The output lists three kernels: two vectorized elementwise kernels
# (presumably sin and cos) plus the concatenation kernel (CatArrayBatc...).
# NOTE(review): a single profiled call is noisy; several iterations would give
# a more reliable comparison.
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("embedding_eager"):
        sincos_eager(embeddings)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                        cudaLaunchKernel        88.10%     348.000us        88.10%     348.000us     116.000us       0.000us         0.00%       0.000us       0.000us             3  
#                                   cudaDeviceSynchronize        11.90%      47.000us        11.90%      47.000us      47.000us       0.000us         0.00%       0.000us       0.000us             1  
# void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      12.000us        24.49%      12.000us      12.000us             1  
# void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      12.000us        24.49%      12.000us      12.000us             1  
# void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      25.000us        51.02%      25.000us      25.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  

# Profile a single call of the fused nvFuser path. This measures only the fused
# sin/cos kernel: the torch.cat needed to reproduce sincos_eager's output is
# not part of this call, whereas the eager profile does include its cat kernel.
# A like-for-like comparison should add that cat's cost to this measurement.
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("embedding_nvFuser"):
        python_frontend_sincos(sincos, embeddings)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                   cudaDeviceSynchronize       100.00%      18.000us       100.00%      18.000us      18.000us       0.000us         0.00%       0.000us       0.000us             1  
# CudaCodeGen::kernel9(CudaCodeGen::Tensor<float, 2>, ...         0.00%       0.000us         0.00%       0.000us       0.000us      17.000us       100.00%      17.000us      17.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
2 Likes

Created a follow-up feature request: "[feature request, idea] Fused torch.sincos(x) or cossin(x) - somewhat complex exponential" (pytorch/pytorch issue #90559 on GitHub).