Select By Index

Hi,
I have segmentation maps of size [5, 64, 64], where the batch size is 5 and the spatial size is 64x64.
I also have a feature map F with the same spatial size 64x64 and N channels.

I select the locations in the segmentation map whose label is v. Now I want to select the features in F at those same locations.

Here’s what it looks like:

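import torch

seg = torch.randint(0, 5, size=(5, 64, 64))   # batch of 5 label maps, 64x64

F = torch.rand(5, 25, 64, 64)                 # batch of 5 feature maps, 25 channels

locations = seg == 1                          # boolean mask, shape (5, 64, 64)
# F[locations] ??  how do I select the features of F at these locations?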

Example:

seg with batch_size=2 and spatial_size=2x2, flattened to 4, so its size is (2,4)
seg = [[0,1,2,3],
       [1,2,1,2]]
F with 3 channels, the size is (2,3,4)
F = [[[2,3,4,3],
      [4,1,2,4],
      [4,0,5,1]],

     [[1,2,0,4],
      [0,1,10,3],
      [1,2,0,5]]]
location = seg == 1

output = F[location] = [[3,1,0],[1,0,1],[0,10,0]]
(one row of channel values for each location where seg == 1)

Your example isn’t consistent with your explanation above. Given your explanation, seg and F should have sizes (2, 2, 2) and (2, 3, 2, 2) respectively:

seg = [[[0, 1],
        [2, 3]],
       [[1, 2],
        [1, 2]]]

import numpy as np

mask = np.array(seg)

print(mask.shape)  # (2, 2, 2)

F = [[[[2, 3],
       [4, 3]],
      [[4, 1],
       [2, 4]],
      [[4, 0],
       [5, 1]]],
     [[[1, 2],
       [0, 4]],
      [[0, 1],
       [10, 3]],
      [[1, 2],
       [0, 5]]]]

features = np.array(F)

print(features.shape)  # (2, 3, 2, 2)

output = features[mask == 1]  # raises IndexError: mask is (2, 2, 2) but the first three dims of features are (2, 3, 2)
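A minimal sketch of one way to make the shapes line up, assuming the channel axis can be moved last: once features are laid out as (batch, H, W, C), the (batch, H, W) mask matches the leading axes and boolean indexing returns one row of channel values per selected location. The names channels_last and output are illustrative.

```python
import numpy as np

# Same toy data as above: seg is (batch, H, W), F is (batch, C, H, W)
seg = np.array([[[0, 1],
                 [2, 3]],
                [[1, 2],
                 [1, 2]]])
F = np.array([[[[2, 3], [4, 3]],
               [[4, 1], [2, 4]],
               [[4, 0], [5, 1]]],
              [[[1, 2], [0, 4]],
               [[0, 1], [10, 3]],
               [[1, 2], [0, 5]]]])

# Move channels last: (batch, C, H, W) -> (batch, H, W, C)
channels_last = F.transpose(0, 2, 3, 1)

# The (batch, H, W) boolean mask now matches the leading axes,
# so the result has shape (num_selected, C)
output = channels_last[seg == 1]
print(output)
# [[ 3  1  0]
#  [ 1  0  1]
#  [ 0 10  0]]
```

This reproduces the output from the original example. In PyTorch the same idea would be F.permute(0, 2, 3, 1)[seg == 1]; rows from different samples are concatenated, so which sample each row came from is lost.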

But in that case I don’t understand what you want to do. If you have a batch of segmentation masks and a batch of feature maps, don’t you want to treat each sample in the batch independently? Based on your example, it seems you want the positive locations across the whole batch, which implicitly fuses the batched segmentation masks. Is that really what you want?
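If each sample should be handled independently, a minimal sketch (assuming PyTorch and the shapes from the question; per_sample, selected and batch_idx are illustrative names) either loops over the batch or keeps the batch index of every selected location alongside a fused result:

```python
import torch

seg = torch.randint(0, 5, size=(5, 64, 64))   # (batch, H, W) labels
F = torch.rand(5, 25, 64, 64)                 # (batch, C, H, W) features
v = 1                                         # label to select

# Option 1: one (num_selected_i, C) tensor per sample
per_sample = [f.permute(1, 2, 0)[s == v] for f, s in zip(F, seg)]

# Option 2: one fused (num_selected, C) tensor plus the sample each row came from
location = seg == v
selected = F.permute(0, 2, 3, 1)[location]        # (num_selected, C)
batch_idx = location.nonzero(as_tuple=True)[0]    # (num_selected,)

# Both traverse locations in row-major order, so the fused rows
# belonging to sample 0 equal the per-sample result for sample 0
assert torch.equal(selected[batch_idx == 0], per_sample[0])
```

Option 1 keeps samples separate at the cost of a Python loop; option 2 stays vectorized and recovers the per-sample split through batch_idx.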

For instance, let’s set the batch issue aside and assume you have one segmentation mask and one set of features:

import torch

mask     = torch.randint(0, 5, size=(64, 64))
features = torch.rand(25, 64, 64)

You can then use torch.masked_select:

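class_mask = mask == 1                       # boolean (64, 64) mask; same as torch.where(mask == 1, True, False)
output = features.masked_select(class_mask)  # mask broadcasts over the 25 channels; returns a 1-D tensor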
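Note that masked_select broadcasts the (64, 64) mask over the 25 channels and returns a flat 1-D tensor, laid out channel by channel. A small sketch (same shapes as above; n, per_channel, direct and per_pixel are illustrative names) of recovering a per-channel layout, which matches plain boolean indexing over the two spatial dimensions:

```python
import torch

mask = torch.randint(0, 5, size=(64, 64))
features = torch.rand(25, 64, 64)
class_mask = mask == 1

n = int(class_mask.sum())                        # number of selected pixels
flat = features.masked_select(class_mask)        # 1-D, length 25 * n
per_channel = flat.view(features.shape[0], n)    # (25, n)

# Equivalent: boolean-index the two spatial dims directly
direct = features[:, class_mask]                 # (25, n)
assert torch.equal(per_channel, direct)

# Transpose for one row of 25 channel values per selected pixel
per_pixel = per_channel.t()                      # (n, 25)
```

Passing n explicitly to view avoids the ambiguity that view(25, -1) would hit if no pixel carried the label.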

Hi, that is exactly what I’m looking for. Thank you!