I am trying to expand a [200, 176, 2] binary mask so that it selects from a [200, 176, 14] tensor: the first 7 elements along the tensor's third dimension (size 14) should be selected by mask[:, :, 0] and the last 7 by mask[:, :, 1]. For example, if the mask along the third dimension is [0,1], the selection should behave as if the mask were [0,0,0,0,0,0,0,1,1,1,1,1,1,1]. I managed to solve it with the lengthy code below, but I imagine there must be a shorter and more straightforward way (and one that avoids NumPy, since I intend to run this on the GPU).
Goal in short: use the [200, 176, 2] binary mask b to select from the [200, 176, 14] tensor a.
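To make the intended behaviour concrete, here is a tiny sketch with a single position instead of the full 200 x 176 grid. The names a_small, b_small and expanded are hypothetical and used only for illustration; the expanded mask is written out by hand to show what I want, not how to compute it.

```python
import torch

# Tiny hypothetical case: one position holding 14 values, plus a 2-element mask.
a_small = torch.arange(14.0).reshape(1, 1, 14)
b_small = torch.tensor([[[False, True]]])  # plays the role of [0, 1]

# Desired expanded mask, written by hand:
# 7 copies of b_small[..., 0] followed by 7 copies of b_small[..., 1].
expanded = torch.tensor([[[False] * 7 + [True] * 7]])

# Boolean indexing returns a flat 1-D tensor of the selected elements.
selected = a_small[expanded]
print(selected)  # the last seven values of a_small, 7.0 through 13.0
```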
My current code (works as expected, but is very lengthy):
import torch
# tensor to select from
a = torch.rand([200,176,14])
# mask (bool dtype; indexing with a uint8 mask is deprecated in newer PyTorch versions)
b = torch.zeros([200,176,2], dtype=torch.bool)
# split mask by the last dimension
mask_parts = torch.split(b, 1, dim=2)
# first part, size torch.Size([200, 176, 1])
mask1 = mask_parts[0]
# expand to torch.Size([200, 176, 7])
mask1 = mask1.expand(-1,-1,7)
# second part, identical processing to the first
mask2 = mask_parts[1]
mask2 = mask2.expand(-1,-1,7)
# join masks, get [200, 176, 14]
mask = torch.cat((mask1, mask2), 2)
# now the goal - use the mask to select elements from a
result = a[mask]
Is there a better way to achieve the same result, or is this OK?
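One shorter possibility, offered as a sketch rather than a definitive answer: repeat_interleave repeats each mask element along a given dimension, which appears to replace the split/expand/cat sequence in one call. The example assumes a PyTorch version that provides Tensor.repeat_interleave and torch.bool, and it uses a random mask purely for illustration. The grouped variant is a further assumption: it views a as [200, 176, 2, 7] and indexes with b directly, so no 14-wide mask is built at all.

```python
import torch

a = torch.rand([200, 176, 14])
# Random boolean mask, for illustration only (the original post uses an all-zero mask).
b = torch.rand([200, 176, 2]) > 0.5

# Repeat every mask element 7 times along the last dimension:
# [m0, m1] becomes [m0]*7 followed by [m1]*7, giving shape [200, 176, 14].
mask = b.repeat_interleave(7, dim=2)
result = a[mask]  # flat 1-D tensor of the selected elements

# Cross-check against the split/expand/cat construction from the question.
mask_ref = torch.cat([p.expand(-1, -1, 7) for p in torch.split(b, 1, dim=2)], dim=2)
same_mask = torch.equal(mask, mask_ref)

# Variant without materialising the expanded mask: group the last dimension
# into 2 blocks of 7 and index the first three dimensions with b.
# This yields one row of 7 values per True entry in b.
grouped = a.view(200, 176, 2, 7)[b]  # shape [number of True entries, 7]
```

Both operations run on whatever device the tensors live on, so moving a and b to the GPU beforehand should be enough. Flattening grouped with reshape(-1) gives the elements in the same order as result.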