Gathering elements along Z axis in 3D tensor following a mask

I have a tensor like this example (x):

tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])

Using a tensor mask like this:

tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [1, 1, 0, 0]])

I want to gather the elements of the 3D tensor along the Z axis (dim 0) at every position where the mask is 1, producing one row per selected position, like this:

tensor([[ 9, 21],
        [ 2, 14],
        [10, 22],
        [ 7, 19]])

So far, I’ve only managed to get this result in a very ugly way:

x = x.transpose(0, 2).reshape(-1, x.shape[0])
idx = mask.transpose(0, 1).flatten().nonzero()
result = x[idx].squeeze()

Is there a better way to do it?

Thanks!

You could create a bool mask and index the input tensor as:

import torch

x = torch.tensor([[[ 1,  2,  3,  4],
                   [ 5,  6,  7,  8],
                   [ 9, 10, 11, 12]],

                  [[13, 14, 15, 16],
                   [17, 18, 19, 20],
                   [21, 22, 23, 24]]])

mask = torch.tensor([[0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [1, 1, 0, 0]]).bool()

res = x[mask.unsqueeze(0).expand(2, -1, -1)]
print(res)
> tensor([ 2,  7,  9, 10, 14, 19, 21, 22])
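The boolean index walks the tensor in row-major order, so the flat result lists all selected values from the first Z slice, then all from the second. As a minimal sketch of the same idea (the names flat and per_slice are only illustrative), expand_as(x) avoids hardcoding the size of dim 0, and indexing just the last two dims with the 2D mask keeps the slice dimension:

```python
import torch

# Same values as the example tensor: shape (2, 3, 4), entries 1..24.
x = torch.arange(1, 25).reshape(2, 3, 4)
mask = torch.tensor([[0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [1, 1, 0, 0]]).bool()

# Broadcast the (3, 4) mask to x's full shape without hardcoding dim 0.
flat = x[mask.unsqueeze(0).expand_as(x)]
print(flat)             # 1D tensor: slice 0's picks, then slice 1's picks

# Apply the 2D mask to the last two dims only; dim 0 is preserved.
per_slice = x[:, mask]
print(per_slice.shape)  # one row per Z slice, one column per mask hit
```

Here per_slice has shape (2, 4), so it already groups the values by slice before any view or transpose.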

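To get the posted output shape, you could use the following (note that the rows come out in the mask's row-major order, not in the row order shown in the question):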

res = res.view(2, -1).transpose(0, 1)
print(res)
> tensor([[ 2, 14],
          [ 7, 19],
          [ 9, 21],
          [10, 22]])
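If the exact row order from the question matters, the mask positions need to be visited column by column. One possible sketch, assuming the same x and mask as above (the name exact is only illustrative), moves dim 0 to the end and indexes with the transposed mask:

```python
import torch

# Same values as the example tensor: shape (2, 3, 4), entries 1..24.
x = torch.arange(1, 25).reshape(2, 3, 4)
mask = torch.tensor([[0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [1, 1, 0, 0]]).bool()

# permute(2, 1, 0) gives shape (4, 3, 2): (column, row, Z slice).
# mask.t() has shape (4, 3), so it selects (column, row) positions in
# column-major order of the original mask and keeps the Z values together.
exact = x.permute(2, 1, 0)[mask.t()]
print(exact)
```

For the example data this yields a (4, 2) tensor whose rows follow the order in the question: [9, 21], [2, 14], [10, 22], [7, 19].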