Complex matrix multiplication

Given the documentation stating that

Operations on complex tensors (e.g., torch.mv(), torch.matmul()) are likely to be faster and more memory efficient than operations on float tensors mimicking them.

I would expect matmul to be implemented for complex tensors; however, when I try to execute the following:

import torch

a = torch.tensor([[1.4 + 3j, 2 + 5j], [1.4 + 3j, 2 + 5j]], dtype=torch.cfloat)
a @ a

I get RuntimeError: _th_addmm_out not supported on CPUType for ComplexFloat.

This also happens when using torch.matmul or torch.mm instead of the @ operator.

Am I doing something wrong?
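
In the meantime, the only way I found is to mimic the complex product with float tensors, as the documentation quote above alludes to. A minimal sketch, assuming the real and imaginary parts are kept as separate float tensors and using (A + iB)(C + iD) = (AC - BD) + i(AD + BC):

import torch

# Real and imaginary parts of a, stored as separate float tensors.
ar = torch.tensor([[1.4, 2.0], [1.4, 2.0]])
ai = torch.tensor([[3.0, 5.0], [3.0, 5.0]])

# (A + iB)(C + iD) = (AC - BD) + i(AD + BC), here with C = A and D = B.
real = ar @ ar - ai @ ai
imag = ar @ ai + ai @ ar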

Edit: The error occurs on both CPU and GPU.

Update: matrix-vector multiplication works if it is done explicitly with torch.mv(), but matmul does not dispatch to it automatically when, e.g., multiplying shapes (2, 2) x (2, 1).
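
For concreteness, here is what I mean (v is just an arbitrary example vector):

import torch

a = torch.tensor([[1.4 + 3j, 2 + 5j], [1.4 + 3j, 2 + 5j]], dtype=torch.cfloat)
v = torch.tensor([1 + 1j, 2 - 1j], dtype=torch.cfloat)

torch.mv(a, v)       # works: explicit matrix-vector product
a @ v.reshape(2, 1)  # fails on this build: matmul does not dispatch to mv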

It would be really nice to know whether this is expected, in the sense of "not implemented yet", or whether it can be done differently.