For example, I have a 3D tensor like this:
import torch
a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5
And I have a 2D “mask” like this:
b = torch.zeros((3, 3))
b[1, 1] = 1
b[2, 2] = 1
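A minimal sketch of one way to use this mask, assuming the goal is one vector per nonzero mask position taken along the first dimension of ‘a’. It finds the (row, col) coordinates of the mask's nonzero entries with torch.nonzero and then uses advanced indexing on the last two dimensions of ‘a’. The variable names rows, cols and vectors are illustrative only.

```python
import torch

# Recreate the example tensors from the question.
a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5
b = torch.zeros((3, 3))
b[1, 1] = 1
b[2, 2] = 1

# (row, col) coordinates of every nonzero mask entry, in row-major order.
rows, cols = torch.nonzero(b, as_tuple=True)

# Indexing the last two dims gives shape (3, k), one column per masked position;
# transposing gives one 3-element vector per masked position, shape (k, 3).
vectors = a[:, rows, cols].T
print(vectors)
```

Here k is the number of ones in the mask, so for this example the result has shape (2, 3).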
I want to use the mask ‘b’ to extract a list of 3D vectors from ‘a’, so the output should contain two vectors:
[[2, 2, 2], [5, 5, 5]]
I tried constructions like torch.masked_select, but it always returns a 1D tensor that does not preserve the “vectorized” grouping of the elements. It returns a tensor like this:
[2, 5, 2, 5, 2, 5]
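The interleaving comes from broadcasting: masked_select broadcasts the (3, 3) mask to the (3, 3, 3) shape of ‘a’ and reads the selected elements in row-major order, so each leading slice a[i] contributes its k masked values in turn. A minimal sketch, assuming the mask is converted to bool: because every slice contributes the same number of values, reshaping the flat result to (3, k) and transposing recovers the vectors.

```python
import torch

# Recreate the example tensors from the question.
a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5
b = torch.zeros((3, 3))
b[1, 1] = 1
b[2, 2] = 1

# The mask is broadcast over the first dim; the output is flat and row-major.
flat = torch.masked_select(a, b.bool())
print(flat)  # tensor([2., 5., 2., 5., 2., 5.])

# Row i of the reshaped tensor holds the k masked values of slice a[i];
# transposing groups the values per masked position instead.
vectors = flat.view(a.shape[0], -1).T
print(vectors)
```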
How can I get the correct result using PyTorch operations?
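A minimal sketch using boolean-mask indexing directly, assuming ‘b’ can be converted to a bool tensor. Indexing the last two dimensions of ‘a’ with a (3, 3) boolean mask yields shape (3, k), which is then transposed. Alternatively, moving the vector dimension to the end with permute lets the mask index the first two dimensions and returns shape (k, 3) with no transpose.

```python
import torch

# Recreate the example tensors from the question.
a = torch.ones((3, 3, 3))
a[:, 1, 1] = 2
a[:, 2, 2] = 5
b = torch.zeros((3, 3))
b[1, 1] = 1
b[2, 2] = 1

mask = b.bool()

# Option 1: mask the last two dims -> shape (3, k), then transpose to (k, 3).
vectors_t = a[:, mask].T

# Option 2: put the vector dim last -> shape (3, 3, 3) as (row, col, vec),
# then mask the first two dims -> shape (k, 3) directly.
vectors_p = a.permute(1, 2, 0)[mask]

print(vectors_t)
print(vectors_p)
```

Both options return the vectors in the row-major order of the ones in the mask.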
Thanks in advance!