I have a `torch.Tensor` of shape `(2, 2, 2)` (it can be bigger), where the values are normalized to the range `[0, 1]`.

Now I am given a positive integer `K`, which tells me that I need to create a mask where, for each 2D tensor inside the batch, a value is 1 if it is larger than the `1/K` quantile of that 2D tensor's values, and 0 elsewhere. The returned mask also has shape `(2, 2, 2)`.

For example, if I have a batch like this:

```
tensor([[[1., 3.],
         [2., 4.]],
        [[5., 7.],
         [9., 8.]]])
```

and let `K=2`, it means that I must mask the values that are greater than 50% of all the values inside each 2D tensor.

In the example, the 0.5 quantiles of the two 2D tensors are `2.5` and `7.5`, so this is the desired output:

```
tensor([[[0, 1],
         [0, 1]],
        [[0, 0],
         [1, 1]]])
```
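As a sanity check on those two thresholds, here is a minimal sketch that computes one quantile per 2D tensor. It assumes `torch.quantile` with its default linear interpolation; the names `batch` and `thresholds` are only illustrative.

```python
import torch

batch = torch.tensor([[[1., 3.],
                       [2., 4.]],
                      [[5., 7.],
                       [9., 8.]]])
K = 2

# torch.quantile without `dim` flattens its input, so each call
# returns a single scalar threshold for one 2D tensor.
thresholds = torch.stack([torch.quantile(x, 1 / K) for x in batch])
print(thresholds)  # tensor([2.5000, 7.5000])
```

With linear interpolation, the 0.5 quantile of `[1, 2, 3, 4]` falls halfway between 2 and 3, and that of `[5, 7, 8, 9]` halfway between 7 and 8.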

I tried:

```
import torch

K = 2
a = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
# One 1/K quantile per 2D tensor in the batch -> shape (2,)
quantile = torch.tensor([torch.quantile(x, 1/K) for x in a])
torch.where(a > quantile, 1, 0)
```

But this is the result:

```
tensor([[[0, 0],
         [0, 0]],
        [[1, 0],
         [1, 1]]])
```
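My guess is that this is a broadcasting issue: a threshold tensor of shape `(2,)` lines up with the last dimension of `a`, so `2.5` and `7.5` get compared against the two columns of every 2D tensor instead of one threshold per 2D tensor. The sketch below is one possible way around that, assuming the thresholds are reshaped to `(batch, 1, 1)` before the comparison; `q` and `mask` are illustrative names.

```python
import torch

a = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
K = 2

# Flatten each 2D tensor into a row and take the 1/K quantile per row.
# Result has shape (batch,), here tensor([2.5, 7.5]).
q = torch.quantile(a.flatten(start_dim=1), 1 / K, dim=1)

# Reshape to (batch, 1, 1) so that threshold i is broadcast over
# every element of a[i] rather than over the last dimension.
mask = (a > q.view(-1, 1, 1)).long()
print(mask)
```

On the example batch this yields `[[0, 1], [0, 1]]` for the first 2D tensor and `[[0, 0], [1, 1]]` for the second, which matches the desired output, and it avoids the Python loop over the batch.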