Hi,
I’m trying to implement an efficient complex multiply in PyTorch, since this is the bottleneck in my application.
Here’s the code in Python:
import torch

def complex_mult_torch(X, Y):
    # The last dimension holds (real, imag): (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    assert X.shape[-1] == 2 and Y.shape[-1] == 2, 'Last dimension must be 2'
    return torch.stack(
        (X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1],
         X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0]),
        dim=-1)
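To make the expected behaviour concrete, here is a minimal sketch of a reference check that any C++/CUDA port could be compared against. It repeats the function so the snippet runs on its own, and it shows the broadcasting case in question (X has an extra leading batch dimension). It assumes a PyTorch version that provides torch.view_as_complex and torch.view_as_real (1.6 or later); complex_mult_native is a hypothetical helper name used only for this comparison.

```python
import torch

def complex_mult_torch(X, Y):
    # The last dimension holds (real, imag): (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    assert X.shape[-1] == 2 and Y.shape[-1] == 2, 'Last dimension must be 2'
    return torch.stack(
        (X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1],
         X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0]),
        dim=-1)

def complex_mult_native(X, Y):
    # Hypothetical reference helper, assuming torch.view_as_complex exists.
    # view_as_complex needs a last dimension of size 2 with stride 1,
    # so the inputs are made contiguous first.
    Xc = torch.view_as_complex(X.contiguous())
    Yc = torch.view_as_complex(Y.contiguous())
    return torch.view_as_real(Xc * Yc)

torch.manual_seed(0)
X = torch.randn(4, 3, 2)  # 4 x 3 complex numbers
Y = torch.randn(3, 2)     # 3 complex numbers, broadcast over the leading dimension
out = complex_mult_torch(X, Y)
ref = complex_mult_native(X, Y)
print(out.shape, torch.allclose(out, ref, atol=1e-6))
```

Both versions rely on PyTorch's normal broadcasting rules on every dimension except the last one, which is the behaviour a CUDA kernel would have to reproduce.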
How would I implement this in C++ with ATen/CUDA?
I’ve looked at https://github.com/pytorch/extension-cpp but it’s not clear to me how to do slicing with ATen Tensors.
For CUDA, how would I handle broadcasting if X and Y are not of the same shape?