Expanding multidimensional tensor by non-singleton dimension

I am trying to expand a [200, 176, 2] binary mask to select from a [200, 176, 14] tensor, so that the first 7 elements along the tensor’s 3rd dimension (size 14) are selected by mask[:, :, 0] and the last 7 elements by mask[:, :, 1]. E.g. if my mask along the third dimension is [0,1], the selection is made as if it were [0,0,0,0,0,0,0,1,1,1,1,1,1,1]. I managed to solve it with a rather lengthy piece of code, but I imagine there must be a shorter and more straightforward way (ideally without NumPy, as I intend to run this on the GPU).

Goal in short: use [200, 176, 2] binary mask b to select from [200, 176, 14] tensor a
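To make the intended mapping concrete, here is a toy version (shapes shrunk to [1, 1, 2] and [1, 1, 14], values 0..13 chosen only for readability). It expands the mask by hand with expand + cat, the same idea as the code in the question, and shows that a mask value of [0, 1] selects exactly the last 7 elements.

```python
import torch

# Toy tensor to select from: shape [1, 1, 14] holding 0..13
a = torch.arange(14).reshape(1, 1, 14)

# Mask value [0, 1] along the last dimension (boolean dtype)
b = torch.tensor([[[False, True]]])

# Expand each mask entry over its group of 7 consecutive elements
mask = torch.cat([b[..., :1].expand(-1, -1, 7),
                  b[..., 1:].expand(-1, -1, 7)], dim=-1)

print(mask.int().flatten().tolist())  # [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

result = a[mask]
print(result.tolist())  # [7, 8, 9, 10, 11, 12, 13]
```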

My current code (works as expected, but very lengthy):

import torch

# tensor to select from
a = torch.rand([200, 176, 14])

# mask (boolean; indexing with torch.uint8 masks is deprecated in recent PyTorch)
b = torch.zeros([200, 176, 2], dtype=torch.bool)

# split mask by the last dimension
mask_parts = torch.split(b, 1, dim=2)

# first part, size torch.Size([200, 176, 1])
mask1 = mask_parts[0]
# expand to torch.Size([200, 176, 7])
mask1 = mask1.expand(-1,-1,7)

# second part, identical processing to the first
mask2 = mask_parts[1]
mask2 = mask2.expand(-1,-1,7)

# join masks, get [200, 176, 14]
mask = torch.cat((mask1, mask2), 2)

# now the goal - use the mask to select elements from a
result = a[mask]

Is there a better way to achieve the same result or is it OK?

To expand the [200, 176, 2] mask to size [200, 176, 14], you can do the following:

new_mask = b.unsqueeze(-1).repeat(1, 1, 1, 7).view(200, 176, -1)
print(torch.equal(new_mask, mask))  # True
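Two variants of the same idea, as a sketch: torch.repeat_interleave repeats each entry of the last dimension consecutively, which is exactly the [m0]*7 + [m1]*7 mapping wanted, and unsqueeze + expand + reshape does the "go to a 4th dimension and back" trick without an explicit repeat (reshape still copies, since the expanded view isn't contiguous). The helper names expand_mask_interleave and expand_mask_view are hypothetical, and both assume a.shape[-1] is a multiple of b.shape[-1].

```python
import torch

def expand_mask_interleave(b, size):
    # Repeat every entry of the last dim `group` times in a row:
    # [m0, m1] -> [m0]*7 + [m1]*7 when size == 14
    group = size // b.shape[-1]
    return b.repeat_interleave(group, dim=-1)

def expand_mask_view(b, size):
    # Temporary 4th dim: [..., 2] -> [..., 2, 7] (expanded view) -> [..., 14]
    group = size // b.shape[-1]
    return b.unsqueeze(-1).expand(*b.shape, group).reshape(*b.shape[:-1], size)

a = torch.rand(200, 176, 14)
b = torch.rand(200, 176, 2) > 0.5  # random boolean mask for testing

# Reference: the split/expand/cat approach from the question
ref = torch.cat([p.expand(-1, -1, 7) for p in torch.split(b, 1, dim=2)], dim=2)

m1 = expand_mask_interleave(b, a.shape[-1])
m2 = expand_mask_view(b, a.shape[-1])
print(torch.equal(m1, ref), torch.equal(m2, ref))  # True True
```

Both helpers work on any device, so the mask can stay on the GPU alongside a.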

Thank you! Expanding into a 4th dimension and back was something I couldn’t think of myself!
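The same 4th-dimension trick can also be applied to a instead of the mask, as a sketch: view a as [200, 176, 2, 7] and index it directly with the original [200, 176, 2] mask. A boolean mask covering the leading dimensions selects whole groups of 7, giving a [K, 7] tensor (K = number of True entries in b), and flattening it gives the same elements in the same row-major order as a[mask]. This assumes a is contiguous so that .view works; .reshape would be the fallback otherwise.

```python
import torch

a = torch.rand(200, 176, 14)
b = torch.rand(200, 176, 2) > 0.5  # random boolean mask for testing

# Split the last dim of `a` into 2 groups of 7 so it lines up with `b`
a_groups = a.view(200, 176, 2, 7)

# A mask over the first 3 dims selects whole groups of 7: shape [K, 7]
selected = a_groups[b]

# Flatten to match the 1-D result of indexing with an expanded mask
result = selected.flatten()

# Cross-check against the expanded-mask approach
mask = b.unsqueeze(-1).expand(-1, -1, -1, 7).reshape(200, 176, 14)
print(torch.equal(result, a[mask]))  # True
```

Keeping the [K, 7] form can also be handy when each group of 7 is meant to be processed as a unit.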