Efficient complex-real transform

I need to solve a system of linear equations in the complex domain in a backpropagable (differentiable) way.
As stated in this research report (https://hal-ens-lyon.archives-ouvertes.fr/ensl-00125369v2/document), complex numbers can be represented as real 2x2 matrices (a + ib as [[a, -b], [b, a]]), so a complex system of linear equations of size N can be solved with a real solver on a 2N x 2N real matrix.
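The equivalence is easy to check numerically. Below is a minimal sketch, assuming PyTorch >= 1.8 (native complex tensors, `torch.kron`, `torch.linalg.solve`); the helper names `complex_to_real_mat` and `complex_to_real_vec` are hypothetical. Each entry a + ib of A becomes the block [[a, -b], [b, a]], each entry of b becomes the pair (Re, Im), and the real solve should recover the same x as the complex solve.

```python
import torch


def complex_to_real_mat(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (N, N) complex -> (2N, 2N) real with interleaved 2x2 blocks.

    Each entry a + ib becomes a * I2 + b * J, with J = [[0, -1], [1, 0]],
    i.e. the block [[a, -b], [b, a]].
    """
    eye2 = torch.eye(2, dtype=a.real.dtype, device=a.device)
    j = torch.tensor([[0.0, -1.0], [1.0, 0.0]], dtype=a.real.dtype, device=a.device)
    return torch.kron(a.real, eye2) + torch.kron(a.imag, j)


def complex_to_real_vec(b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (N,) complex -> (2N,) real laid out as [re0, im0, re1, im1, ...]."""
    return torch.stack([b.real, b.imag], dim=-1).reshape(-1)


torch.manual_seed(0)
n = 4
a = torch.randn(n, n, dtype=torch.complex128)
b = torch.randn(n, dtype=torch.complex128)

# Reference: native complex solve
x_complex = torch.linalg.solve(a, b)

# Same system solved with a real solver on a 2N x 2N matrix
x_real = torch.linalg.solve(complex_to_real_mat(a), complex_to_real_vec(b))
# De-interleave back into a complex vector
x_from_real = torch.complex(x_real[0::2], x_real[1::2])

print(torch.allclose(x_complex, x_from_real))
```

`torch.kron`, `.real`/`.imag` and `torch.linalg.solve` all support autograd, so gradients can flow back to `a` and `b` through the real formulation as well.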

So I’m converting vectors that contain the concatenated real and imaginary parts into matrices of this type. I have a working example:

import torch

def vec_to_cmat(vec, dim=-1):
    """ Expects tensor with shape 2*(N**2) on dim 'dim' (only dim=-1 works for now)"""
    # Get matrix size (before CR transform); round() guards against float error in the sqrt
    vec_shape = list(vec.shape)
    mtxc_size = round((vec_shape[dim] // 2) ** 0.5)
    # Drop the vector dimension, keep all remaining ones (matrix goes last)
    mat_shape = vec_shape[:-1] + [2 * mtxc_size, 2 * mtxc_size]
    mat = vec.new_zeros(mat_shape)
    # Create indexes on the same device as vec.
    # Element k = c*N + r of re/im goes to block (r, c) (column-major order).
    line_idx = torch.tensor([2 * m for _ in range(mtxc_size) for
                             m in range(mtxc_size)], dtype=torch.long, device=vec.device)
    col_idx = torch.tensor([2 * n for n in range(mtxc_size) for
                            _ in range(mtxc_size)], dtype=torch.long, device=vec.device)
    # Split real and imaginary parts
    re, im = torch.split(vec, [mtxc_size**2, mtxc_size**2], dim=dim)
    # re, im = torch.chunk(vec, 2, dim=dim)  # works as well

    # Fill each 2x2 block as [[re, -im], [im, re]]
    mat[..., line_idx, col_idx] = re
    mat[..., line_idx+1, col_idx+1] = re
    mat[..., line_idx+1, col_idx] = im
    mat[..., line_idx, col_idx+1] = -im
    return mat

inp = torch.randn(10, 2, 18)
complex_mat = vec_to_cmat(inp)  # (10, 2, 6, 6)

I’ve been looking for a nicer way to assign the values, without luck so far. Could anybody help, please?
Also, this only works for dim=-1 for now. Any idea (other than transposing vec) on how to make it work for any dim?
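One possible alternative to the four index assignments, sketched here for dim=-1 and assuming the column-major ordering of `vec_to_cmat` is intended (element k = c*N + r is entry (r, c)): reshape re and im into N x N matrices, stack them into 2x2 blocks, and let a transpose plus reshape interleave the blocks. The name `vec_to_cmat_stack` is hypothetical; it uses only `chunk`, `reshape`, `transpose` and `stack`, all of which are differentiable.

```python
import torch


def vec_to_cmat_stack(vec: torch.Tensor) -> torch.Tensor:
    """Hypothetical alternative to vec_to_cmat for dim=-1.

    vec has shape (..., 2 * N**2): N**2 real parts followed by N**2 imaginary
    parts, in the same column-major order as vec_to_cmat. Returns (..., 2N, 2N).
    """
    n = round((vec.shape[-1] // 2) ** 0.5)
    re, im = vec.chunk(2, dim=-1)
    # (..., N*N) -> (..., N, N); transpose because element k = c*N + r is entry (r, c)
    re = re.reshape(*re.shape[:-1], n, n).transpose(-1, -2)
    im = im.reshape(*im.shape[:-1], n, n).transpose(-1, -2)
    # 2x2 blocks [[re, -im], [im, re]] -> shape (..., N, N, 2, 2)
    blocks = torch.stack(
        [torch.stack([re, -im], dim=-1), torch.stack([im, re], dim=-1)], dim=-2
    )
    # (..., r, c, p, q) -> (..., r, p, c, q) -> (..., 2N, 2N)
    return blocks.transpose(-3, -2).reshape(*vec.shape[:-1], 2 * n, 2 * n)


inp = torch.randn(10, 2, 18)
print(vec_to_cmat_stack(inp).shape)  # torch.Size([10, 2, 6, 6])
```

The final `reshape` copies once (the transposed tensor is not contiguous), instead of four scattered writes into a zero tensor.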
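For an arbitrary dim, one option that avoids both index assignment and moving dimensions is to treat the transform as the linear map it is: contract vec with a constant 0/±1 tensor of shape (2N², 2N, 2N) using `torch.tensordot`. `tensordot` puts the uncontracted dimensions of its first argument first and those of its second argument last, so the matrix lands in the last two dimensions whatever `dim` is. This is a sketch (the name `vec_to_cmat_any_dim` is hypothetical) that assumes the same column-major ordering as `vec_to_cmat`; the constant tensor has 8 * N**4 entries, so it trades memory and a dense contraction for simplicity and only makes sense for small N (it could also be cached per N).

```python
import torch


def vec_to_cmat_any_dim(vec: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical sketch: expects 2 * N**2 entries on `dim`; returns (..., 2N, 2N)."""
    dim = dim % vec.dim()  # normalise negative dims
    n = round((vec.shape[dim] // 2) ** 0.5)
    # Same column-major order as vec_to_cmat: element k = c*N + r is entry (r, c)
    r = torch.arange(n, device=vec.device).repeat(n)
    c = torch.arange(n, device=vec.device).repeat_interleave(n)
    k = torch.arange(n * n, device=vec.device)
    # Constant embedding tensor: maps each input entry to its positions in the matrix
    emb = vec.new_zeros(2 * n * n, 2 * n, 2 * n)
    emb[k, 2 * r, 2 * c] = 1                # re on the diagonal of each 2x2 block
    emb[k, 2 * r + 1, 2 * c + 1] = 1
    emb[n * n + k, 2 * r + 1, 2 * c] = 1    # +im below the block diagonal
    emb[n * n + k, 2 * r, 2 * c + 1] = -1   # -im above it
    # Contract `dim` of vec with dim 0 of emb; the matrix dims come last
    return torch.tensordot(vec, emb, dims=([dim], [0]))


inp = torch.randn(18, 10, 2)
print(vec_to_cmat_any_dim(inp, dim=0).shape)  # torch.Size([10, 2, 6, 6])
```

If a transpose is acceptable after all, `torch.movedim(vec, dim, -1)` returns a view (no copy), so moving the vector dimension to the end before calling the dim=-1 code is also cheap.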