nn.utils.rnn.pad_sequence with padding values taken from the sequence itself?

I have a list of tensors with different lengths, like:

tensor_list = [
    torch.tensor([    # shape [4, 3]
        [8, 8, 3],
        [7, 2, 9],
        [2, 4, 3],
        [3, 3, 3]
    ]),
    torch.tensor([    # shape [2, 3]
        [1, 3, 4],
        [5, 5, 5]
    ]), 
    torch.tensor([    # shape [3, 3]
        [3, 3, 3],
        [7, 7, 9],
        [2, 1, 3]
    ])
]

To pad these tensors to a regular shape, the following code works:

regular_tensor = nn.utils.rnn.pad_sequence(tensor_list, batch_first=True, padding_value=0)

The resulting tensor has shape [3, 4, 3]:

torch.tensor([
    [
        [8, 8, 3],
        [7, 2, 9],
        [2, 4, 3],
        [3, 3, 3]
    ],
    [
        [1, 3, 4],
        [5, 5, 5],
        [0, 0, 0],    # <-- padded value
        [0, 0, 0]     # <-- padded value
    ],
    [
        [3, 3, 3],
        [7, 7, 9],
        [2, 1, 3],
        [0, 0, 0]    # <-- padded value
    ]
])
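A self-contained version of the zero-padding call above, with the imports it needs, is sketched below. `pad_sequence` pads every tensor along its first dimension up to the longest length and stacks the results, so three tensors of lengths 4, 2 and 3 with 3 columns each give a [3, 4, 3] tensor.

```python
import torch
from torch import nn

# Same ragged inputs as above: lengths 4, 2 and 3, each row a triplet
tensor_list = [
    torch.tensor([[8, 8, 3], [7, 2, 9], [2, 4, 3], [3, 3, 3]]),
    torch.tensor([[1, 3, 4], [5, 5, 5]]),
    torch.tensor([[3, 3, 3], [7, 7, 9], [2, 1, 3]]),
]

# Pad along dim 0 of each tensor with zeros and stack into [batch, max_len, 3]
regular_tensor = nn.utils.rnn.pad_sequence(tensor_list, batch_first=True, padding_value=0)
print(regular_tensor.shape)  # torch.Size([3, 4, 3])
```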

But instead of zeros, I would like the padded rows to be copies of rows (triplets) that already exist in the same tensor. Which row is copied (first, last or random) doesn't matter much, though a random choice is preferred. The padded tensor would then look something like:

torch.tensor([
    [
        [8, 8, 3],
        [7, 2, 9],
        [2, 4, 3],
        [3, 3, 3]
    ],
    [
        [1, 3, 4],
        [5, 5, 5],
        [5, 5, 5],    # <-- randomly selected value from 1 row above
        [1, 3, 4]     # <-- randomly selected value from 3 rows above
    ],
    [
        [3, 3, 3],
        [7, 7, 9],
        [2, 1, 3],
        [7, 7, 9]     # <-- randomly selected value from 2 rows above
    ]
])

How is this best achieved?

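I wrote a function for this. It is fairly specific to my case (rows of 3 values) and probably not optimal, but it works:

import torch

def pad_colors_random_choice(sequences):
    # Length of the longest sequence in the batch
    longest_colors = max(len(s) for s in sequences)

    # Match the inputs' dtype and device (torch.empty defaults to float32)
    out_tensor = torch.empty([len(sequences), longest_colors, 3],
                             dtype=sequences[0].dtype, device=sequences[0].device)
    for i, colors in enumerate(sequences):
        n_colors = len(colors)
        diff = longest_colors - n_colors
        # Draw `diff` random row indices (with replacement) from this sequence
        pad = colors[torch.randint(n_colors, [diff], device=colors.device)]
        out_tensor[i] = torch.cat([colors, pad])

    return out_tensor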
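A loop-free alternative is sketched below, assuming every tensor is 2-D with the same number of columns and at least one row. It zero-pads with `pad_sequence` first, then replaces each padded slot with a randomly chosen valid row of the same sequence using `torch.gather`. The name `pad_random_rows` is hypothetical; it is not a built-in PyTorch function.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_random_rows(sequences):
    # Hypothetical helper: zero-pad first to get a regular [B, T, F] tensor
    padded = pad_sequence(sequences, batch_first=True)
    B, T, F = padded.shape
    lengths = torch.tensor([len(s) for s in sequences], device=padded.device)

    # Time-step index of every slot, shape [B, T]
    positions = torch.arange(T, device=padded.device).expand(B, T)
    # Random row index in [0, length) per slot; the clamp guards against
    # float rounding turning u * length into exactly length
    rand_idx = (torch.rand(B, T, device=padded.device) * lengths[:, None]).long()
    rand_idx = torch.minimum(rand_idx, lengths[:, None] - 1)

    # Keep real rows where they exist, use a random row in padded slots
    idx = torch.where(positions < lengths[:, None], positions, rand_idx)
    # Gather whole rows along dim 1 (index broadcast over the feature dim)
    return torch.gather(padded, 1, idx[..., None].expand(B, T, F))

tensor_list = [
    torch.tensor([[8, 8, 3], [7, 2, 9], [2, 4, 3], [3, 3, 3]]),
    torch.tensor([[1, 3, 4], [5, 5, 5]]),
    torch.tensor([[3, 3, 3], [7, 7, 9], [2, 1, 3]]),
]
out = pad_random_rows(tensor_list)
print(out.shape)  # torch.Size([3, 4, 3])
```

Like the loop version, this samples with replacement, so the same row may fill several padded slots. The output keeps the inputs' dtype because `pad_sequence` and `torch.gather` both preserve it.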