4D boolean indexing with masks of lower dimension


I have a 4D tensor A which is boolean and of shape (L,n,L,a), and two 1D tensors mask1 and mask2 both of shape (L,).

I would like to be able to do the following:
A[mask1,0,mask2,:3] = True
followed by
A[mask2,0,mask1,1:2] = True

Of course, the above does not work; it fails with a shape mismatch error. How can I accomplish this indexing/setting in a different way?
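For context on where the shape mismatch comes from: a boolean mask used as an index behaves like the integer positions of its True entries (as `mask.nonzero()` would return), and all such index tensors in one expression are paired element-wise, so they must broadcast against each other. The small sketch below uses arbitrary example masks with 2 and 3 True entries to reproduce the failure. The exact exception type and message may vary by PyTorch version, so both `IndexError` and `RuntimeError` are caught.

```python
import torch

L, n, a = 4, 2, 5
A = torch.zeros((L, n, L, a), dtype=torch.bool)

# Arbitrary example masks: mask1 has 2 True entries, mask2 has 3.
mask1 = torch.tensor([True, True, False, False])
mask2 = torch.tensor([True, False, True, True])

# Each boolean mask acts like the integer indices of its True entries.
idx1 = mask1.nonzero(as_tuple=True)[0]  # tensor([0, 1])
idx2 = mask2.nonzero(as_tuple=True)[0]  # tensor([0, 2, 3])

# The index tensors (shapes (2,) and (3,)) are paired element-wise,
# so they must broadcast together; here they cannot.
try:
    A[mask1, 0, mask2, :3] = True
    failed = False
except (IndexError, RuntimeError) as e:
    failed = True
    print("error:", e)
```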

Hi David!

Could you clarify your question a bit?

What are the types of mask1 and mask2 and what would some example
values for them be?

As a simpler example, let’s say that A had shape [L, L]. In such a case
what, specifically, would you like A[mask1, mask2] = True to do?


K. Frank
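One reason this question matters: if mask1 and mask2 happen to contain the same number of True entries, the two-dimensional assignment does not raise an error at all. PyTorch (like NumPy) pairs the True positions element-wise instead of taking every row/column combination, which is probably not what is wanted. The mask values below are only an example.

```python
import torch

L = 4
A = torch.zeros((L, L), dtype=torch.bool)

# Example masks with the same number of True entries (2 each).
mask1 = torch.tensor([True, True, False, False])
mask2 = torch.tensor([True, False, True, False])

# Rows [0, 1] are zipped with columns [0, 2],
# so only positions (0, 0) and (1, 2) are set.
A[mask1, mask2] = True
print(A.nonzero().tolist())  # [[0, 0], [1, 2]]
```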


Hi K. Frank, thanks for your response.

mask1 and mask2 are both boolean tensors containing some mix of True and False values along their single dimension.

In response to the second part of your question: in the end, I want A to actually be a mask for another tensor. Specifically, in my problem I have a loss function that outputs a (L,n,L,a) tensor of floats, but many of the values in that tensor should not contribute to the final loss. To have only specific entries contribute to the loss, I'm trying to create a mask tensor (A) that is also 4D and has True only in the positions that I want contributing to the loss. I will then multiply the 4D loss tensor and the 4D mask together to zero out the unneeded contributions. Hopefully this clarifies a bit of context.
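A minimal sketch of the masking step described above, assuming the loss is an unreduced (L,n,L,a) float tensor (for example from a loss called with `reduction="none"`); the random loss values and the mask pattern are placeholders. It shows the multiply-by-mask approach alongside direct boolean selection.

```python
import torch

L, n, a = 6, 5, 13

# Placeholder per-element loss of shape (L, n, L, a).
torch.manual_seed(0)
loss = torch.rand(L, n, L, a)

# Boolean mask of the same shape: True where an entry should count.
A = torch.zeros((L, n, L, a), dtype=torch.bool)
A[:2, 0, :3, :3] = True  # arbitrary example pattern (18 entries)

# Multiplying promotes the bool mask to float, zeroing the unwanted entries.
masked_sum = (loss * A).sum()

# Selecting with the mask keeps only the wanted entries (a 1D tensor),
# so the mean is taken over those entries alone.
masked_mean = loss[A].mean()
```

Note that `(loss * A).mean()` would divide by all L*n*L*a elements, zeros included; dividing `masked_sum` by `A.sum()` or using `loss[A].mean()` normalizes over the selected entries only.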

Here’s also a more concrete example:

import torch 

L = 100
n = 5
a = 13

A = torch.zeros((L,n,L,a)).bool()

mask1 = torch.randint(0,2,(L,)).bool()
mask2 = torch.randint(0,2,(L,)).bool()

print(A[mask1,0].shape) # works fine 

print(A[mask1,0,mask2,:3]) # error 
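A minimal sketch, assuming the intended meaning of `A[mask1,0,mask2,:3] = True` is to set every position (i, 0, j, k) with k < 3 where mask1[i] and mask2[j] are both True (an outer product of the two masks rather than an element-wise pairing). Under that assumption, the two masks can be combined into a single (L, L) mask first, or converted to integer indices shaped as a column and a row so that they broadcast over all (i, j) combinations. The names `pair`, `i1`, `i2` and `B` are illustrative only.

```python
import torch

L, n, a = 100, 5, 13
A = torch.zeros((L, n, L, a), dtype=torch.bool)
mask1 = torch.randint(0, 2, (L,)).bool()
mask2 = torch.randint(0, 2, (L,)).bool()

# Assumed meaning: set every (i, j) with mask1[i] and mask2[j] both True.
pair = mask1[:, None] & mask2[None, :]  # shape (L, L)

# A[:, 0, :, :3] is a view of shape (L, L, 3); assigning through the
# (L, L) boolean mask writes into A itself.
A[:, 0, :, :3][pair] = True
# Second assignment from the question (masks swapped):
# pair.T[i, j] equals mask2[i] & mask1[j].
A[:, 0, :, 1:2][pair.T] = True

# Equivalent alternative: integer indices shaped (k1, 1) and (1, k2)
# broadcast to every combination, similar to NumPy's np.ix_.
i1 = mask1.nonzero(as_tuple=True)[0]
i2 = mask2.nonzero(as_tuple=True)[0]
B = torch.zeros_like(A)
B[i1[:, None], 0, i2[None, :], :3] = True
B[i2[:, None], 0, i1[None, :], 1:2] = True
```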

Hi David!

It’s still not clear precisely what you want to do.

Speaking in terms of the simplified two-dimensional example, it looks like
you want to use mask1 to mask the rows of A and use mask2 to mask the
columns.

But how should this work?

Do you want to keep all rows where mask1 is True and all columns where
mask2 is True?

Do you only want to keep elements for which both mask1 and mask2 are
True?

Or some other use case?

Here’s an illustration of these alternatives:

>>> import torch
>>> print (torch.__version__)
>>> mask1 = torch.tensor ([True, True, False, False])
>>> mask2 = torch.tensor ([True, False, True, False])
>>> t = torch.arange (16.0).reshape (4, 4) + 1.0
>>> t
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
>>> t[torch.logical_and (mask1.unsqueeze (1), mask2.unsqueeze (0))]
tensor([1., 3., 5., 7.])
>>> t[torch.logical_or (mask1.unsqueeze (1), mask2.unsqueeze (0))]
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 11., 13., 15.])
>>> t[torch.logical_not (torch.logical_and (mask1.unsqueeze (1), mask2.unsqueeze (0)))] = 0.0
>>> t
tensor([[1., 0., 3., 0.],
        [5., 0., 7., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])


K. Frank
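The `logical_and` / `unsqueeze` construction in the illustration above extends to all four dimensions, building the (L,n,L,a) mask directly instead of writing into a preallocated tensor. This is a sketch under the "both masks True" interpretation; the helper `outer4` and the per-dimension masks `n0`, `a_first3` and `a_1` are hypothetical names that encode the `0`, `:3` and `1:2` parts of the original question.

```python
import torch

L, n, a = 100, 5, 13
mask1 = torch.randint(0, 2, (L,)).bool()
mask2 = torch.randint(0, 2, (L,)).bool()

# 1D masks for the other two dimensions.
n0 = torch.arange(n) == 0        # index 0 of the n-dimension
a_first3 = torch.arange(a) < 3   # slice :3
a_1 = torch.arange(a) == 1       # slice 1:2 selects only index 1

def outer4(m_row, m_n, m_col, m_a):
    # Broadcast four 1D masks to shape (L, n, L, a) and AND them together.
    return (m_row[:, None, None, None] & m_n[None, :, None, None]
            & m_col[None, None, :, None] & m_a[None, None, None, :])

# Union of the two assignments from the question.
A = outer4(mask1, n0, mask2, a_first3) | outer4(mask2, n0, mask1, a_1)
```

Because the result is an ordinary (L,n,L,a) bool tensor, it can be multiplied with the loss tensor as described earlier in the thread.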