Suppose I have a tensor A with the following shape:

```
torch.Size([5, 16, 5000, 3])
```

I also have a boolean mask M of the same shape:

```
torch.Size([5, 16, 5000, 3])
```

If I apply this mask M directly to the tensor A via

```
A = A[M]
```

I end up with a flattened, one-dimensional tensor.
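A minimal reproduction of this behaviour, assuming a random boolean mask (the values are made up for illustration). Indexing with a mask that has the same shape as the tensor selects individual elements, so the result is 1-D with one entry per True value in M:

```python
import torch

# Random data with the shape from the question
A = torch.randn(5, 16, 5000, 3)
# Hypothetical boolean mask of the same shape (random values, for illustration only)
M = torch.rand(5, 16, 5000, 3) > 0.5

out = A[M]  # full-shape boolean mask: element-wise selection
print(out.dim())                    # 1
print(out.numel() == int(M.sum()))  # True: one element per True entry
```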

However, I would like to apply the mask only along dimension 2. In other words, I would like to get a tensor of shape

```
torch.Size([5, 16, 5000 - N, 3])
```

where N is the number of positions along dimension 2 for which M is False.
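A sketch of the target shape, under the assumption that M varies only along dimension 2 (that is, it is identical for every index in dimensions 0, 1 and 3); otherwise each slice could keep a different number of rows and the result would not be rectangular. The names `keep` and `keep_1d` and the choice of dropping the first 100 positions are hypothetical. The idea is to collapse M to a 1-D mask and use it to index dimension 2 only:

```python
import torch

A = torch.randn(5, 16, 5000, 3)

# Hypothetical 1-D mask over dimension 2: drop the first 100 positions
keep = torch.ones(5000, dtype=torch.bool)
keep[:100] = False

# Assumed full-shape mask that varies only along dimension 2
M = keep.view(1, 1, -1, 1).expand_as(A)

# Collapse M to a 1-D mask and check that it really is constant across dims 0, 1, 3
keep_1d = M[0, 0, :, 0]
assert torch.equal(M, keep_1d.view(1, 1, -1, 1).expand_as(M))

# Index dimension 2 only; the other dimensions are kept whole
B = A[:, :, keep_1d, :]
N = int((~keep_1d).sum())
print(B.shape)  # torch.Size([5, 16, 4900, 3])
```

Because only one dimension receives a boolean index, PyTorch leaves that dimension in place instead of flattening the tensor.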

How can I do this?