I have a 3D tensor x like this:
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],
        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])
I also have a mask with the same shape as the last two dimensions of x:
tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [1, 1, 0, 0]])
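I want to gather the elements of x along the Z axis (dim 0) at every position where the mask is 1, so that each selected position gives one row. The rows follow the mask's ones column by column, producing this result: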
tensor([[ 9, 21],
        [ 2, 14],
        [10, 22],
        [ 7, 19]])
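For reference, here is a quick sketch that rebuilds the example tensors (assuming x is just arange(1, 25) reshaped to (2, 3, 4), which matches the printout above). It shows that plain boolean indexing over the last two dimensions already picks the right Z-pairs, but in row-major order of the mask, not the column-by-column order of the expected result:

```python
import torch

# Rebuild the example tensors (assumption: x is arange(1, 25) reshaped)
x = torch.arange(1, 25).reshape(2, 3, 4)   # shape (Z=2, rows=3, cols=4)
mask = torch.tensor([[0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [1, 1, 0, 0]])        # shape (rows=3, cols=4)

# A boolean mask over dims 1 and 2 keeps dim 0 (Z) intact,
# giving shape (Z, number_of_ones) = (2, 4)
picked = x[:, mask.bool()]

# Transpose so each row holds the Z-vector of one selected position
pairs = picked.T
# pairs.tolist() -> [[2, 14], [7, 19], [9, 21], [10, 22]]
# Same pairs as the expected result, but in row-major mask order
```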
So far, I've managed to get this result, but only in a very ugly way:
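# (Z, rows, cols) -> (cols, rows, Z), then flatten to one Z-vector per row
x = x.transpose(0, 2).reshape(-1, x.shape[0])
# flat indices of the ones, in column-major order of the mask
idx = mask.transpose(0, 1).flatten().nonzero().squeeze(1)
result = x[idx]  # squeeze(1) keeps shape (N, Z) even when N == 1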
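For comparison, the same column-by-column gather can probably be written as a single indexing expression. This is a sketch checked only against this example; the helper name gather_z_colmajor is hypothetical, and x is again rebuilt as arange(1, 25).reshape(2, 3, 4) to match the printout:

```python
import torch

# Rebuild the example tensors (assumption: x is arange(1, 25) reshaped)
x = torch.arange(1, 25).reshape(2, 3, 4)   # shape (Z, rows, cols)
mask = torch.tensor([[0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [1, 1, 0, 0]])        # shape (rows, cols)

def gather_z_colmajor(x, mask):
    """Hypothetical helper: return the Z-vectors of x where mask is nonzero,
    ordered column by column (column-major over the mask)."""
    # Permute to (cols, rows, Z); row-major order over the first two dims
    # of this view equals column-major order over the original (rows, cols)
    return x.permute(2, 1, 0)[mask.T.bool()]

result = gather_z_colmajor(x, mask)
# result.tolist() -> [[9, 21], [2, 14], [10, 22], [7, 19]]
```

permute only creates a view, so the single copy happens in the boolean indexing step, and the result keeps shape (N, Z) even when the mask has a single 1.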
Is there a better way to do it?
Thanks!