A pattern found in the wild in DiffusionInst/head.py (chenhaoxing/DiffusionInst on GitHub):
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
This pattern often appears in positional-encoding computation and sometimes in random Fourier features; a minimal sketch of the typical use is shown below.
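For context, here is how the pattern typically shows up in a sinusoidal timestep/positional embedding (illustrative only, not the exact code from the linked repo):

import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # hypothetical sinusoidal embedding in the spirit of the linked code
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32, device=t.device) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat((args.sin(), args.cos()), dim=-1)

emb = timestep_embedding(torch.arange(16), dim=64)  # shape (16, 64)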
Is there a nice, fast, fused way of doing this in PyTorch?
This is equivalent to taking the complex exponential (it could even be a feature request for a separate alias such as torch.sincos(x)), but in NumPy at least the complex exponential is a bit slower than computing sin and cos separately, despite the existence of an SSE sincos_ps recipe (see "Is there a fast way to return Sin and Cos of the same value in Python?" on Stack Overflow and http://gruntthepeon.free.fr/ssemath/sse_mathfun.h).
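Just to spell out the equivalence, here is a sketch using torch.polar (not claiming this is any faster):

import torch

x = torch.randn(4, 8)
# exp(i*x) = cos(x) + i*sin(x), so one complex op yields both values
z = torch.polar(torch.ones_like(x), x)  # same as exp(1j * x)
out_complex = torch.cat((z.imag, z.real), dim=-1)
out_eager = torch.cat((x.sin(), x.cos()), dim=-1)
print(torch.allclose(out_complex, out_eager, atol=1e-6))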
Ideally, this pattern can also be fused in the future.
That’s an interesting use case and yes, it should be possible to write a fused version of this operation.
One approach would be to use the nvFuser Python frontend, which is available in the latest nightly releases.
Here is a complete code snippet showing the fused kernel. The limitation is the torch.cat operation: it is currently not supported in the nvFuser Python frontend (still WIP), so you would still need to execute an additional eager kernel for the concatenation.
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch._C._nvfuser import Fusion, FusionCache, FusionDefinition, DataType
from typing import List
def sincos_eager(embeddings):
    # eager reference: two elementwise kernels plus a cat kernel
    return torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)

def nvFuser_sincos(fd: FusionDefinition):
    # fused sin/cos over a 2D float tensor; cat is not yet supported here
    embeddings = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Float)
    output1 = fd.ops.sin(embeddings)
    output2 = fd.ops.cos(embeddings)
    #output = fd.ops.cat((output1, output2), dim=dim)
    fd.add_output(output1)
    fd.add_output(output2)

sincos = Fusion()
with FusionDefinition(sincos) as fd:
    nvFuser_sincos(fd)

def python_frontend_sincos(
    fs: Fusion, input: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([input])
    return out
# setup
embeddings = torch.randn(1024, 1024, device='cuda')
# compare
out_pynvf = python_frontend_sincos(sincos, embeddings)
out_pynvf = torch.cat(out_pynvf, -1)
out_eager = sincos_eager(embeddings)
print((out_pynvf - out_eager).abs().max())
# tensor(0., device='cuda:0')
# profile
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("embedding_eager"):
        sincos_eager(embeddings)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# cudaLaunchKernel 88.10% 348.000us 88.10% 348.000us 116.000us 0.000us 0.00% 0.000us 0.000us 3
# cudaDeviceSynchronize 11.90% 47.000us 11.90% 47.000us 47.000us 0.000us 0.00% 0.000us 0.000us 1
# void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 12.000us 24.49% 12.000us 12.000us 1
# void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 12.000us 24.49% 12.000us 12.000us 1
# void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 25.000us 51.02% 25.000us 25.000us 1
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("embedding_nvFuser"):
        python_frontend_sincos(sincos, embeddings)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
# cudaDeviceSynchronize 100.00% 18.000us 100.00% 18.000us 18.000us 0.000us 0.00% 0.000us 0.000us 1
# CudaCodeGen::kernel9(CudaCodeGen::Tensor<float, 2>, ... 0.00% 0.000us 0.00% 0.000us 0.000us 17.000us 100.00% 17.000us 17.000us 1
# ------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
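Until cat is supported inside the fusion, a small wrapper (just a sketch reusing the sincos Fusion defined above) gives a drop-in replacement for the eager version:

def fused_sincos_cat(embeddings: torch.Tensor) -> torch.Tensor:
    # fused sin/cos from the nvFuser definition above,
    # followed by an eager torch.cat until cat is supported in the fusion
    s, c = sincos.execute([embeddings])
    return torch.cat((s, c), dim=-1)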