ATen/CUDA implementation of complex multiply?


I’m trying to implement an efficient complex multiply in PyTorch, since it is the bottleneck in my application.
Here’s the code in Python:

import torch

def complex_mult_torch(X, Y):
    # The last dimension holds the (real, imaginary) parts.
    assert X.shape[-1] == 2 and Y.shape[-1] == 2, 'Last dimension must be 2'
    return torch.stack(
        (X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1],   # real part
         X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0]),  # imaginary part
        dim=-1)

How would I implement this in C++ with ATen/CUDA?
I’ve looked at but it’s not clear to me how to do slicing with ATen Tensors.
For CUDA, how would I handle broadcasting if X and Y are not of the same shape?

You could use Tensor.narrow (I used that in Linear.cpp for the _trilinear op). I’m not sure I would expect a large speedup from doing this in C++ though.
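
To make the narrow suggestion concrete, here is a rough sketch in Python rather than C++, since the Python Tensor.narrow(dim, start, length) takes the same arguments as the ATen method and the C++ version should read almost line for line the same. The function name complex_mult_narrow is made up for illustration, and this is only one way to write it, not a tuned implementation. Keeping length=1 preserves the last dimension, so the elementwise ops broadcast X and Y as usual and no manual broadcasting code is needed.

```python
import torch

def complex_mult_narrow(X, Y):
    # Hypothetical helper: same math as complex_mult_torch, using narrow
    # instead of indexing. narrow(dim, start, length) returns a view.
    assert X.shape[-1] == 2 and Y.shape[-1] == 2, 'Last dimension must be 2'
    dx, dy = X.dim() - 1, Y.dim() - 1
    xr, xi = X.narrow(dx, 0, 1), X.narrow(dx, 1, 1)  # real, imaginary of X
    yr, yi = Y.narrow(dy, 0, 1), Y.narrow(dy, 1, 1)  # real, imaginary of Y
    # The trailing size-1 dimension is kept, so the products broadcast over
    # the leading dimensions and cat rebuilds the (real, imag) pair.
    return torch.cat((xr * yr - xi * yi, xr * yi + xi * yr), dim=-1)
```

For example, X of shape (4, 1, 2) and Y of shape (3, 2) give a result of shape (4, 3, 2).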

Best regards