Is there any func to do "select and pad" in PyTorch?

Hey guys, I want to select tensor B along some dimension by some condition of A, specifically,

assume:
A = [[0, 1, 1],
     [1, 3, 4],
     [2, 0, 0]]  (3x3)

B = torch.randn(3, 3, 10)

I want to select the slices B[i, j, :] of B (each 1x1x10, along dim 2) by some condition on A, say A[i, j] == 1,

mask = (A == 1):
[[False, True,  True],
 [True,  False, False],
 [False, False, False]]

if mask[i, j] is True, select B[i, j, :] (and 'move it to the left'),

we will get:
result =
[[B[0, 1, :], B[0, 2, :]],
 [B[1, 0, :]],
 []]

and add some padding:
result =
[[B[0, 1, :], B[0, 2, :], PAD],
 [B[1, 0, :], PAD, PAD],
 [PAD, PAD, PAD]]

I wrote a loop version like:

import torch


def select_and_pad(
    input_repr, ner_label_indices, indices, max_len,
):
    """Select tensors by the specified ner indices along an axis.

    Args:
        input_repr (tensor): [bs, seq_len, hidden]
        ner_label_indices (tensor): [bs, seq_len]
        indices (list): specific entity indices to keep
        max_len (int): max seq len after truncating/padding

    Returns:
        selected_repr (tensor): [bs, max_len, hidden]
    """
    bs, seq_len, hidden_dim = input_repr.shape

    selected_repr = []

    for i in range(bs):
        curr = []
        
        for j in range(seq_len):
            curr_label = ner_label_indices[i][j].item()
            
            if curr_label in indices:
                curr.append(input_repr[i, j, :])

        # truncate
        if len(curr) > max_len:
            curr = curr[:max_len]

        # pad zeros (float32, so stacking promotes an integer input to float)
        curr.extend([
            torch.zeros(hidden_dim, device=input_repr.device)
            for _ in range(max_len - len(curr))
        ])
        
        selected_repr.append(torch.stack(curr, dim=0))

    selected_repr = torch.stack(selected_repr, dim=0)

    return selected_repr


A = torch.tensor([[0,1,1], [1,3,4]])
B = torch.arange(1, 7).view(2,3).unsqueeze(2).repeat(1,1,10)

select_indices = [1]  # select B[i, j, :] where A[i, j] == 1

print(f"A: {A.shape}\n", A, "\n")
print(f"B: {B.shape}\n", B, "\n")

selection = select_and_pad(B, A, select_indices, max_len=3)

print(f"result: {selection.shape}\n", selection)

the result is:

A: torch.Size([2, 3])
 tensor([[0, 1, 1],
        [1, 3, 4]]) 

B: torch.Size([2, 3, 10])
 tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],

        [[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
         [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6, 6, 6, 6]]]) 

result: torch.Size([2, 3, 10])
 tensor([[[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
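For reference, here's a vectorized sketch I came up with using only built-in ops (torch.isin, torch.sort, torch.gather; it assumes PyTorch >= 1.10 for torch.isin). A stable sort on the mask pushes the selected rows to the front while preserving their order, and the rest are zeroed out. Unlike my loop version, it keeps the input's dtype rather than promoting to float:

```python
import torch


def select_and_pad_vec(input_repr, ner_label_indices, indices, max_len):
    """Vectorized select-and-pad: selected rows move left, the rest become zeros."""
    bs, seq_len, hidden = input_repr.shape
    # mask[i, j] is True where the label at (i, j) is one of `indices`
    mask = torch.isin(ner_label_indices, torch.as_tensor(indices))
    # stable descending sort pushes True entries to the front, preserving order
    order = torch.sort(mask.int(), dim=1, descending=True, stable=True).indices
    gathered = torch.gather(
        input_repr, 1, order.unsqueeze(-1).expand(-1, -1, hidden)
    )
    # zero out everything past the number of selected rows in each batch item
    counts = mask.sum(dim=1, keepdim=True)                       # [bs, 1]
    keep = torch.arange(seq_len, device=input_repr.device).unsqueeze(0) < counts
    out = gathered * keep.unsqueeze(-1)
    # truncate, or pad with zero rows, up to max_len
    if seq_len >= max_len:
        return out[:, :max_len, :]
    pad = torch.zeros(bs, max_len - seq_len, hidden,
                      dtype=out.dtype, device=out.device)
    return torch.cat([out, pad], dim=1)
```

On my toy example above this matches the loop version (modulo the dtype), but I'm not sure it's the idiomatic way.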

Question is, can I do this with some built-in torch operations (ideally a single function)?
Thanks!