How to select from a 3D tensor with a 2D mask?

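Hi, I have a 3D tensor of shape (batch_size, seq_len, dim), where some positions along the second dimension are zero-padded. I also have a mask generated from the sequence lengths, and I want to use it to select from the tensor. The behavior I want is like masked_select, except that it should return a 2D tensor of shape (number of valid positions, dim) rather than a flattened 1D one.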

import torch

a = torch.rand((3, 3, 3))
# zero-pad the trailing positions of the 2nd and 3rd sequences
a[1, 2] = 0
a[2, 2] = 0
a[2, 1] = 0
print(a)
tensor([[[0.7910, 0.4829, 0.7381],
         [0.9005, 0.2266, 0.5940],
         [0.8811, 0.8379, 0.9670]],

        [[0.3192, 0.9537, 0.1001],
         [0.5695, 0.0185, 0.2561],
         [0.0000, 0.0000, 0.0000]],

        [[0.8885, 0.0043, 0.3867],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])
b = torch.tensor([[1, 1, 1],
                  [1, 1, 0],
                  [1, 0, 0]], dtype=torch.bool)
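The mask b above corresponds to sequence lengths of 3, 2 and 1. A minimal sketch of generating such a mask from a lengths tensor, assuming the lengths are stored in a 1D integer tensor (the names `lengths` and `mask` are hypothetical, not from the original post):

```python
import torch

# Hypothetical per-sequence lengths that reproduce the mask b above.
lengths = torch.tensor([3, 2, 1])
seq_len = 3

# Position j of sequence i is valid when j < lengths[i].
# Broadcasting (1, seq_len) against (batch_size, 1) yields a
# boolean mask of shape (batch_size, seq_len).
mask = torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)

print(mask)
```

The comparison already produces a tensor of dtype torch.bool, so no conversion is needed before using it as an index.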

Expected output

tensor([[0.7910, 0.4829, 0.7381],
        [0.9005, 0.2266, 0.5940],
        [0.3192, 0.9537, 0.1001],
        [0.5695, 0.0185, 0.2561],
        [0.8885, 0.0043, 0.3867]])

Thanks!

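Hi,

You can use the mask directly as an index, so c = a[b] should return your expected output. A boolean mask of shape (batch_size, seq_len) indexes the first two dimensions of a, so every True position keeps its whole row of length dim, and the result has shape (number of True entries, dim).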

Thanks for this elegant solution!