Interleaving a set of channels during concatenation

Hi,

I am trying to perform a quaternion-space concatenation, in which each of the four components [r, i, j, k] has to be concatenated separately. According to quaternion theory, we cannot apply torch.cat directly, because it would mix up the components: concatenating two tensors along the channel dimension gives [r1, i1, j1, k1, r2, i2, j2, k2] instead of [r1, r2, i1, i2, j1, j2, k1, k2]. So I implemented this with the quaternion_concat function below, which is adapted from here.
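
To make the ordering problem concrete, here is a small illustration that uses plain Python lists of channel labels instead of tensors. It assumes each input stores its channels as four contiguous blocks [r..., i..., j..., k...]. The helpers channel_labels and componentwise_concat are hypothetical names used only for this sketch.

```python
# Illustrative sketch with lists of labels instead of tensors.
# Assumption: each input's channels form four contiguous blocks
# [r..., i..., j..., k...], each block holding the same number of channels.

def channel_labels(name, per_comp):
    """Label every channel, e.g. 'r_a0' is the first r channel of input 'a'."""
    return [f"{comp}_{name}{n}" for comp in "rijk" for n in range(per_comp)]


def componentwise_concat(*inputs):
    """Concatenate so that all r blocks come first, then all i, j and k blocks."""
    out = []
    for comp in range(4):
        for channels in inputs:
            per_comp = len(channels) // 4
            out.extend(channels[comp * per_comp:(comp + 1) * per_comp])
    return out


a = channel_labels("a", 1)
b = channel_labels("b", 1)

naive = a + b                       # the channel order torch.cat(..., dim=1) produces
fixed = componentwise_concat(a, b)  # the component-wise order described above

print(naive)  # ['r_a0', 'i_a0', 'j_a0', 'k_a0', 'r_b0', 'i_b0', 'j_b0', 'k_b0']
print(fixed)  # ['r_a0', 'r_b0', 'i_a0', 'i_b0', 'j_a0', 'j_b0', 'k_a0', 'k_b0']
```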

For example, tensor_1 and tensor_2 are two tensors that need to be concatenated, and each has 16 channels: 4 channels each of r, i, j and k, stacked in that order. I use torch.chunk to split each tensor into its four components, concatenate the matching components across the tensors, and then concatenate the four results. Is there a better way to do this?

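import torch


def quaternion_concat(x, dim=2):
    # x: list of tensors whose size along `dim` splits into four equal,
    # contiguous blocks [r, i, j, k]. Collect the r, i, j and k blocks
    # of all inputs into four separate lists.
    output = [[] for _ in range(4)]
    for _x in x:
        sp = torch.chunk(_x, 4, dim=dim)
        for i in range(4):
            output[i].append(sp[i])

    # Concatenate each component across the inputs, then join the four
    # components in order: [r_all, i_all, j_all, k_all].
    final = [torch.cat(o, dim) for o in output]
    return torch.cat(final, dim)


# Two inputs with 16 channels each (4 channels per component).
tensor_1 = torch.randn((1, 16, 64, 64), requires_grad=False)
tensor_2 = torch.randn((1, 16, 64, 64), requires_grad=False)

tensor3 = quaternion_concat([tensor_1, tensor_2], dim=1)  # shape (1, 32, 64, 64)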
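
One possible alternative, sketched under the same layout assumption (each tensor's size along dim splits into four equal, contiguous [r, i, j, k] blocks), replaces the chunk loops with a reshape: view each tensor as (N, 4, C/4, H, W), concatenate all inputs along the per-component axis with a single torch.cat, then flatten the two axes back into one channel axis. quaternion_concat_reshape is a hypothetical name, not part of PyTorch.

```python
import torch


def quaternion_concat_reshape(tensors, dim=1):
    """Sketch: component-wise concatenation via reshape instead of chunk loops.

    Assumes every tensor stores its size along `dim` as four equal,
    contiguous blocks [r, i, j, k] (so that size is divisible by 4).
    """
    dim = dim % tensors[0].dim()  # normalise negative dims, e.g. -3 -> 1
    parts = []
    for t in tensors:
        shape = list(t.shape)
        # Split `dim` into (4, size // 4): axis `dim` now indexes r, i, j, k.
        parts.append(t.reshape(shape[:dim] + [4, shape[dim] // 4] + shape[dim + 1:]))
    # Concatenate all inputs' per-component blocks (axis dim + 1) in one call.
    out = torch.cat(parts, dim=dim + 1)
    # Merge (4, total) back into one axis: [r_all, i_all, j_all, k_all].
    return out.flatten(dim, dim + 1)


t1 = torch.randn(1, 16, 64, 64)
t2 = torch.randn(1, 16, 64, 64)
out = quaternion_concat_reshape([t1, t2], dim=1)
print(out.shape)  # torch.Size([1, 32, 64, 64])
```

For inputs laid out like tensor_1 and tensor_2 this should produce the same channel order as quaternion_concat (all r channels of both inputs, then all i, j and k channels); comparing the two outputs with torch.equal is an easy way to confirm it on your own data. reshape may copy non-contiguous inputs, so the benefit is mainly fewer Python-level operations rather than a guaranteed speed-up.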

Any help is appreciated.
Thank you,
Shreyas Kamath