Hey guys, I want to select slices of tensor B along a dimension, based on a condition on another tensor A. Specifically,
assume:
A = [[0, 1, 1],
[1, 3, 4],
[2, 0, 0]] (3x3)
B = torch.randn(3, 3, 10) (3x3x10)
I want to select slices along dim 2 of B (each of shape 1x1x10) based on a condition on A, say, A[i,j] == 1,
mask = A==1:
[[False, True, True],
[True, False, False],
[False, False, False]]
If mask[i,j] is True, select B[i, j, :] (and shift it to the left within its row).
we will get:
result=
[[B[0, 1, :], B[0, 2, :]],
[B[1, 0, :]],
[]]
and add some padding:
result=
[[B[0, 1, :], B[0, 2, :], PAD],
[B[1, 0, :], PAD, PAD],
[PAD, PAD, PAD]]
I wrote a loop version like:
def select_and_pad(
    input_repr, ner_label_indices, indices, max_len,
):
    """Select vectors whose label is in ``indices`` and left-align them.

    For every batch row, the vectors ``input_repr[i, j, :]`` whose label
    ``ner_label_indices[i][j]`` is contained in ``indices`` are collected in
    their original order, truncated to at most ``max_len`` entries and then
    padded with zero vectors up to exactly ``max_len`` entries.

    Args:
        input_repr (tensor): [bs, seq_len, hidden]
        ner_label_indices (tensor): [bs, seq_len]
        indices (list): label values to keep
        max_len (int): number of vectors per row after truncation/padding

    Returns:
        selected_repr (tensor): [bs, max_len, hidden], on the same device as
        ``input_repr``. Its dtype is ``input_repr``'s dtype promoted with the
        default float dtype (the dtype of the zero padding), regardless of
        whether any row actually needed padding.
    """
    bs, seq_len, hidden_dim = input_repr.shape

    # The padding is default-float zeros; fixing the promoted dtype up front
    # keeps the result dtype independent of how many rows get padded.
    out_dtype = torch.promote_types(input_repr.dtype, torch.get_default_dtype())
    # Build the pad vector once, on the input's device, so stacking it with
    # slices of a non-CPU input does not fail with a device mismatch.
    pad = torch.zeros(hidden_dim, dtype=out_dtype, device=input_repr.device)

    # torch.stack cannot stack an empty list, so an empty batch is returned
    # directly as an empty tensor of the right shape.
    if bs == 0:
        return pad.new_zeros((0, max_len, hidden_dim))

    selected_repr = []
    for i in range(bs):
        # One host transfer per row instead of one `.item()` call per element.
        row_labels = ner_label_indices[i][:seq_len].tolist()
        curr = [
            input_repr[i, j, :]
            for j, label in enumerate(row_labels)
            if label in indices
        ]
        # truncate
        curr = curr[:max_len]
        # pad zeros
        curr.extend([pad] * (max_len - len(curr)))
        if curr:
            selected_repr.append(torch.stack(curr, dim=0))
        else:
            # Only reachable when max_len == 0: nothing to stack.
            selected_repr.append(pad.new_zeros((0, hidden_dim)))

    # `.to` is a no-op whenever at least one row was padded.
    return torch.stack(selected_repr, dim=0).to(out_dtype)
# Toy example: A holds one label per position (2x3) and B holds a length-10
# vector per position (2x3x10) whose entries all equal 3*i + j + 1, which
# makes the selected vectors easy to recognise in the printed output.
A = torch.tensor([[0, 1, 1], [1, 3, 4]])
B = torch.arange(1, 7).view(2, 3, 1).repeat(1, 1, 10)
select_indices = [1]  # select B[i, j, :] where A[i, j] == 1

print(f"A: {A.shape}\n", A, "\n")
print(f"B: {B.shape}\n", B, "\n")

# Keep at most 3 selected vectors per batch row; shorter rows are zero-padded.
selection = select_and_pad(B, A, select_indices, max_len=3)
print(f"result: {selection.shape}\n", selection)
the result is:
A: torch.Size([2, 3])
tensor([[0, 1, 1],
[1, 3, 4]])
B: torch.Size([2, 3, 10])
tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]],
[[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6]]])
result: torch.Size([2, 3, 10])
tensor([[[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
The question is: can I do this with built-in torch operations?
Thanks!