How to select a tensor's columns by mask?

For example, I have a 3D tensor like this:

```python
import torch

a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5
```

And I have a 2D "mask" like this:

```python
b = torch.zeros((3, 3))
b[1, 1] = 1
b[2, 2] = 1
```

I want to get a list of 3D vectors from `a` selected by the mask `b`. The output should contain two vectors:
[[2, 2, 2], [5, 5, 5]]

I tried constructs like `torch.masked_select`, but it always returns a 1D tensor that does not keep the elements grouped by vector. It returns a tensor like this:
[2, 5, 2, 5, 2, 5]
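For reference, this flat output is not scrambled: `torch.masked_select` broadcasts the (3, 3) mask to the shape of `a` and returns the selected elements in row-major order. A minimal sketch, assuming the mask is a bool tensor and that every slice along dim 0 contains the same number of selected elements (true here, because one 2D mask is applied to every slice), shows that a reshape plus a transpose recovers the vectors:

```python
import torch

a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5

# masked_select expects a boolean mask
b = torch.zeros((3, 3), dtype=torch.bool)
b[1, 1] = True
b[2, 2] = True

# The mask is broadcast from (3, 3) to (3, 3, 3); elements come out in
# row-major order: a[0, 1, 1], a[0, 2, 2], a[1, 1, 1], a[1, 2, 2], ...
flat = torch.masked_select(a, b)
print(flat)  # tensor([2., 5., 2., 5., 2., 5.])

# Each slice along dim 0 contributes k = b.sum() elements, so view the
# result as (3, k) and transpose to get one row per masked position.
k = int(b.sum())
vectors = flat.view(a.shape[0], k).T
print(vectors.shape)  # torch.Size([2, 3])
```

This only works because every slice along dim 0 shares the same mask; with a per-slice 3D mask of varying counts, the flat result could not be reshaped this way.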

How can I get the correct result using PyTorch operations?

Thanks in advance!

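Does advanced indexing (`a[:, b]`) work for you? Note that the mask must be a boolean tensor to be used as an index (older PyTorch versions used byte/`uint8` masks; indexing with `uint8` is deprecated in newer versions in favor of `torch.bool`).

```python
import torch

a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5

# Boolean mask over the last two dimensions of a
b = torch.zeros((3, 3), dtype=torch.bool)
b[1, 1] = True
b[2, 2] = True

print(a[:, b])  # shape (3, 2): each selected vector is a column
```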
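`a[:, b]` keeps dim 0 and turns the masked positions into the last dimension, so the result has shape (3, 2) with each selected vector as a column. To get one row per vector, as in the question (`[[2, 2, 2], [5, 5, 5]]`), a minimal sketch is to transpose the result, or to move dim 0 to the end with `permute` before indexing. It also assumes the float mask from the question is first converted with `b.bool()`:

```python
import torch

a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5

# Float mask as in the question; convert it to bool before indexing
b = torch.zeros((3, 3))
b[1, 1] = 1
b[2, 2] = 1
mask = b.bool()

cols = a[:, mask]   # shape (3, 2): selected vectors are columns
rows = cols.T       # shape (2, 3): one row per selected vector

# Alternative: put dim 0 last so the mask indexes the two leading dims
rows_alt = a.permute(1, 2, 0)[mask]  # shape (2, 3)

print(rows)
```

Both forms select the masked positions in row-major order of the mask, so `(1, 1)` comes before `(2, 2)`.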

Best regards

Thomas


Yes, it works well! Thank you!