How can I slice a tensor with the following mask in the easiest way?

I have a tensor representing some images, with shape [batch_size, channel, height, width], and a mask tensor with shape [batch_size, channel]. All elements in the mask tensor are bools: True means I want to keep that channel, while False means I don’t. For example:

image = torch.randn(100, 3, 224, 224)
mask = torch.randn(100, 3).ge(0.5)

I wanted to slice image in the easiest way so I tried image[:, mask] but this didn’t work.

I realized that I have to add an extra condition: the number of True values along dim 1 must be the same for every batch element. Thus, the sliced tensor has shape [batch_size, num_of_true, height, width].

As you’ve already described, directly indexing image won’t work, as the output tensor would have a variable size in dim 1. You could use image[mask], which outputs a tensor of shape [num_of_true, height, width], but I’m unsure if you can work with this result.
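To make the shape concrete, here is a quick sketch using the tensors from the question (the boolean mask covers dims 0 and 1 jointly, so the batch structure is flattened away):

```python
import torch

image = torch.randn(100, 3, 224, 224)
mask = torch.randn(100, 3).ge(0.5)

# A [batch, channel] boolean mask indexes dims 0 and 1 together,
# so the result has one entry per True element in the mask.
out = image[mask]
print(out.shape)  # torch.Size([num_of_true, 224, 224])
```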

Right now I have to use a for loop to do the work element-wise, but I’m worried that the loop will slow things down.

PyTorch can’t have tensors whose size varies along a given dim. Here are some ways you can approach this, depending on your desired outcome:

import torch

# Option 1: apply the same channel filter to all batch elements
images = torch.randn(100, 3, 224, 224)
mask = torch.randn(3).ge(0.5)
print(images[:, mask, :, :].size())

# Option 2: treat all batch/channel combinations together
mask2 = torch.randn(100 * 3).ge(0.5)
images = images.view(100 * 3, 224, 224)
images = images[mask2, ...]
print(images.size())

Thanks. The second way works. However, since the size is fixed along that dim, I thought PyTorch might have a built-in method to slice this.

Not sure of your use case, but have you considered just setting the values you don’t want to keep to zero?
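For instance, the zeroing idea can be done without any loop by broadcasting the mask over the spatial dims (a sketch, not a drop-in for your pipeline):

```python
import torch

images = torch.randn(100, 3, 224, 224)
mask = torch.randn(100, 3).ge(0.5)

# Expand the mask from [batch, channel] to [batch, channel, 1, 1]
# so it broadcasts over height and width; masked-off channels become 0.
zeroed = images * mask[:, :, None, None]
print(zeroed.shape)  # torch.Size([100, 3, 224, 224])
```

This keeps the output shape identical to the input, which can be convenient if later layers expect a fixed channel count.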

Perhaps you can share what your desired output should look like or your overall objective to help us narrow down a solution for you.

OK, I’ll be clearer. I’ll use a smaller tensor to describe my desired output.

images = torch.linspace(0, 14, 15)
images = images.view(5, 3, 1, 1).expand(5, 3, 10, 10) 
#images.shape = [5, 3, 10, 10] and 
#elements in images[i, j] are all 3 * i + j
mask = torch.tensor([[True, False, False],
                     [False, True, False],
                     [False, False, True],
                     [False, True, False],
                     [True, False, False]])

After a PyTorch-style slicing, the output output_tensor should be:

output_tensor = torch.tensor([[[[0] * 10] * 10], 
                              [[[4] * 10] * 10],
                              [[[8] * 10] * 10],
                              [[[10] * 10] * 10],
                              [[[12] * 10] * 10]])
#output_tensor.shape = [5, 1, 10, 10]

Or if the mask is

mask = torch.tensor([[True, True, False],
                     [False, True, True],
                     [True, False, True],
                     [True, True, False],
                     [True, False, True]])

After a PyTorch-style slicing, the output output_tensor should be:

output_tensor = torch.tensor([[[[0] * 10] * 10, [[1] * 10] * 10], 
                              [[[4] * 10] * 10, [[5] * 10] * 10],
                              [[[6] * 10] * 10, [[8] * 10] * 10],
                              [[[9] * 10] * 10, [[10] * 10] * 10],
                              [[[12] * 10] * 10, [[14] * 10] * 10]])
#output_tensor.shape = [5, 2, 10, 10]

Looks to me like in both cases the size of dim 1 is the same within each respective mask. You could just modify the 2nd option to reshape that dim dynamically. Here is how that might work:

images = torch.randn(100, 3, 224, 224)
b, c, h, w = images.shape

# each row of the mask must contain the same number of True values,
# otherwise the final view will fail
mask1 = torch.rand(b, c).argsort(dim=1) < 1  # exactly 1 True per row
mask2 = torch.rand(b, c).argsort(dim=1) < 2  # exactly 2 Trues per row

images_flat = images.view(b * c, h, w)

images1 = images_flat[mask1.view(-1), ...].view(b, -1, h, w)  # [100, 1, 224, 224]
images2 = images_flat[mask2.view(-1), ...].view(b, -1, h, w)  # [100, 2, 224, 224]

currently not at a computer to test the above, so please don’t mind the typos

Yes, this is a great slicing method. I just think that if there were a PyTorch-style way to slice in one line (or, say, in one pair of brackets), it would be more convenient.
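For what it’s worth, when every row of the mask has the same number of True values, the indexing and the reshape can be collapsed into a single expression (a sketch under that uniform-count assumption, using the small example from earlier in the thread):

```python
import torch

# Rebuild the small example: every element of images[i, j] is 3 * i + j.
images = torch.linspace(0, 14, 15).view(5, 3, 1, 1).expand(5, 3, 10, 10)
mask = torch.tensor([[True, False, False],
                     [False, True, False],
                     [False, False, True],
                     [False, True, False],
                     [True, False, False]])

# The 2D boolean mask indexes dims 0 and 1 jointly, giving [num_true, 10, 10];
# the view restores the batch dim because each row has the same True count.
out = images[mask].view(images.size(0), -1, *images.shape[2:])
print(out.shape)  # torch.Size([5, 1, 10, 10])
```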