I need to solve systems of linear equations in the complex domain, in a way that supports backpropagation.
As stated in this research report (https://hal-ens-lyon.archives-ouvertes.fr/ensl-00125369v2/document), a complex number can be represented as a real 2x2 matrix, so a complex system of linear equations can be solved with a real solver on a matrix twice the size.
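As a minimal sketch of that equivalence (random, hypothetical data; assumes a PyTorch version with `torch.linalg.solve`), the snippet below uses the separated block layout `[[Re A, -Im A], [Im A, Re A]]` rather than the interleaved 2x2-per-entry layout. The two layouts differ only by a permutation of rows and columns, so they yield the same solution. It also checks that gradients flow back to the real and imaginary parts.

```python
import torch

torch.manual_seed(0)
n = 3

# Real and imaginary parts of a hypothetical complex system A x = b,
# kept as real leaf tensors so gradients can flow back to them.
# Adding n * I keeps the random matrix well conditioned.
Ar = (torch.randn(n, n, dtype=torch.float64)
      + n * torch.eye(n, dtype=torch.float64)).requires_grad_()
Ai = torch.randn(n, n, dtype=torch.float64, requires_grad=True)
br = torch.randn(n, dtype=torch.float64)
bi = torch.randn(n, dtype=torch.float64)

# Real 2n x 2n system: [[Ar, -Ai], [Ai, Ar]] @ [xr; xi] = [br; bi]
M = torch.cat([torch.cat([Ar, -Ai], dim=1),
               torch.cat([Ai, Ar], dim=1)], dim=0)
x_big = torch.linalg.solve(M, torch.cat([br, bi]))
x = torch.complex(x_big[:n], x_big[n:])

# Reference: solve the complex system directly
x_ref = torch.linalg.solve(torch.complex(Ar, Ai), torch.complex(br, bi))
print(torch.allclose(x, x_ref))

# torch.linalg.solve is differentiable, so gradients reach Ar and Ai
x_big.pow(2).sum().backward()
print(Ar.grad.shape, Ai.grad.shape)
```

The real and imaginary rows come from expanding `(Ar + i Ai)(xr + i xi) = (Ar xr - Ai xi) + i (Ai xr + Ar xi)`.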
So I'm converting vectors containing concatenated real and imaginary parts into this type of matrix. Here is a working example:
```python
import torch

def vec_to_cmat(vec, dim=-1):
    """ Expects tensor with shape 2*(N**2) on dim 'dim'"""
    # Get matrix size (before CR transform)
    vec_shape = list(vec.shape)
    mtxc_size = int((vec_shape[dim]/2)**0.5)
    # Keep all remaining dimensions (matrix goes last)
    vec_shape.pop(dim)
    mat_shape = vec_shape + [2 * mtxc_size, 2 * mtxc_size]
    mat = vec.new_zeros(mat_shape)
    # Create indexes
    line_idx = torch.LongTensor([2 * m for _ in range(mtxc_size) for m in range(mtxc_size)])
    col_idx = torch.LongTensor([2 * n for n in range(mtxc_size) for _ in range(mtxc_size)])
    # Split real and imaginary parts
    re, im = torch.split(vec, [mtxc_size**2, mtxc_size**2], dim=dim)
    # re, im = torch.chunk(vec, 2, dim=dim)  # works as well
    # Assign the re or im parts.
    mat[..., line_idx, col_idx] = re
    mat[..., line_idx+1, col_idx+1] = re
    mat[..., line_idx+1, col_idx] = im
    mat[..., line_idx, col_idx+1] = -im
    return mat

inp = torch.randn(10, 2, 18)
complex_mat = vec_to_cmat(inp)  # (10, 2, 6, 6)
```
I've been looking for a nicer way to assign the values, without luck so far. Could anybody help, please?
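One possible alternative, sketched here for `dim=-1` only, builds each 2x2 block `[[re, -im], [im, re]]` with `torch.stack` and interleaves the blocks with a transpose and a `reshape`, so neither a zero tensor nor index lists are needed. The name `vec_to_cmat_stack` is hypothetical. It aims to reproduce the layout of `vec_to_cmat` above, including the column-major placement implied by `line_idx`/`col_idx` (element `k` of `re` lands in block `(k % N, k // N)`), which is why each part is transposed after reshaping.

```python
import torch

def vec_to_cmat_stack(vec):
    """Hypothetical helper: (..., 2 * N**2) -> (..., 2N, 2N), dim=-1 only."""
    n = round((vec.shape[-1] / 2) ** 0.5)
    re, im = vec.chunk(2, dim=-1)
    # The original index lists place element k at block (k % n, k // n),
    # i.e. column-major order, hence the transpose after the reshape.
    re = re.reshape(*vec.shape[:-1], n, n).transpose(-1, -2)
    im = im.reshape(*vec.shape[:-1], n, n).transpose(-1, -2)
    # One 2x2 block per entry: [[re, -im], [im, re]] -> (..., n, n, 2, 2)
    blocks = torch.stack([torch.stack([re, -im], dim=-1),
                          torch.stack([im, re], dim=-1)], dim=-2)
    # (..., n, n, 2, 2) -> (..., n, 2, n, 2) -> (..., 2n, 2n)
    return blocks.transpose(-3, -2).reshape(*vec.shape[:-1], 2 * n, 2 * n)

inp = torch.randn(10, 2, 18)
mat = vec_to_cmat_stack(inp)
print(mat.shape)
```

Everything is out-of-place (`chunk`, `stack`, `transpose`, `reshape`), so autograd treats it like any other sequence of tensor ops.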
Also, this only works for `dim=-1` for now. Any idea (other than transposing `vec`) on how to make it work for any `dim`?
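Since the output matrix always goes in the last two dimensions, some reordering seems hard to avoid. A cheap option, sketched below, is `torch.movedim`, which returns a view, so no data is copied at that step (the later `reshape` may still copy). This sketch also swaps in a different block construction, a batched Kronecker product `kron(re, I) + kron(im, J)` with `J = [[0, -1], [1, 0]]` written via `torch.einsum`; the function name `vec_to_cmat_any_dim` is hypothetical, and the layout is intended to match `vec_to_cmat`.

```python
import torch

def vec_to_cmat_any_dim(vec, dim=-1):
    """Hypothetical helper: 2 * N**2 entries on `dim` -> `dim` removed,
    (2N, 2N) appended as the last two dimensions."""
    # movedim returns a view; the vector dimension is now last
    vec = vec.movedim(dim, -1)
    n = round((vec.shape[-1] / 2) ** 0.5)
    re, im = vec.chunk(2, dim=-1)
    # Column-major placement, as in the original index lists
    re = re.reshape(*vec.shape[:-1], n, n).transpose(-1, -2)
    im = im.reshape(*vec.shape[:-1], n, n).transpose(-1, -2)
    eye = vec.new_tensor([[1.0, 0.0], [0.0, 1.0]])
    J = vec.new_tensor([[0.0, -1.0], [1.0, 0.0]])
    # Batched Kronecker products; output index order (i, a, j, b)
    mat = (torch.einsum('...ij,ab->...iajb', re, eye)
           + torch.einsum('...ij,ab->...iajb', im, J))
    # Row 2i + a, column 2j + b
    return mat.reshape(*vec.shape[:-1], 2 * n, 2 * n)

inp = torch.randn(10, 18, 2)              # vector dimension in the middle
mat = vec_to_cmat_any_dim(inp, dim=1)
print(mat.shape)
```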