Custom concatenation of 2D torch.Tensor rows

Hi,

let’s say I have a tensor like this:

import torch
a = torch.randint(1, 9, (12, 3))

output:

tensor([[2, 1, 5],
        [2, 4, 7],
        [8, 4, 6],
        [3, 8, 5],
        [7, 3, 5],
        [6, 6, 8],
        [6, 8, 3],
        [5, 1, 4],
        [2, 4, 6],
        [4, 8, 7],
        [3, 3, 6],
        [1, 7, 8]])

and now I want to concatenate specific rows according to the following pattern:

triplets_list = []
for i in range(0, a.shape[0], 4):
    # Get quadruplet
    x1, x2, x3, x4 = a[i:i+4]
            
    # Create concatenated pairs from quadruplet
    pair_x1x2 = torch.cat([x1, x2], dim=0)
    pair_x1x3 = torch.cat([x1, x3], dim=0)
    pair_x3x4 = torch.cat([x3, x4], dim=0)
           
    # Stack concatenated tensors
    triplets_list.append(torch.stack([pair_x1x2, pair_x1x3, pair_x3x4], dim=0))

# Concatenate list of triplets
wanted_result = torch.cat(triplets_list, dim=0)

output:

tensor([[2, 1, 5, 2, 4, 7],
        [2, 1, 5, 8, 4, 6],
        [8, 4, 6, 3, 8, 5],
        [7, 3, 5, 6, 6, 8],
        [7, 3, 5, 6, 8, 3],
        [6, 8, 3, 5, 1, 4],
        [2, 4, 6, 4, 8, 7],
        [2, 4, 6, 3, 3, 6],
        [3, 3, 6, 1, 7, 8]])

So for every group of 4 rows (indexed from 0 to 3), I need to concatenate row 0 with row 1, row 0 with row 2, and row 2 with row 3. Is there a way to vectorize this operation? So far I have only come up with the non-vectorized version shown above.
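
One possible vectorized approach, as a rough sketch rather than a definitive answer: reshape the tensor into groups of 4 rows, pick the left and right rows of each pair with index tensors, and concatenate along the last dimension. The function name concat_pairs and its parameters (left, right, group) are made up for illustration, and the sketch assumes the number of rows is divisible by 4.

```python
import torch

def concat_pairs(a, left=(0, 0, 2), right=(1, 2, 3), group=4):
    """Hypothetical helper: pair rows inside each block of `group` rows.

    Assumes a is 2D with shape (N, D) and N divisible by `group`.
    """
    n, d = a.shape
    # (N, D) -> (N // group, group, D): one slice per quadruplet
    groups = a.reshape(-1, group, d)
    left_idx = torch.tensor(left, device=a.device)
    right_idx = torch.tensor(right, device=a.device)
    # Select the left and right member of every pair, each (G, 3, D),
    # then join them along the feature dimension -> (G, 3, 2 * D)
    pairs = torch.cat([groups[:, left_idx], groups[:, right_idx]], dim=-1)
    # Flatten the groups back into rows -> (3 * G, 2 * D)
    return pairs.reshape(-1, 2 * d)

a = torch.randint(1, 9, (12, 3))
result = concat_pairs(a)  # shape (9, 6), same row order as the loop above
```

Because the pairs are stacked per group before flattening, the row order should match the loop version (x1x2, x1x3, x3x4 for the first quadruplet, then the same for the next one, and so on).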