Create a mask of values larger than the 1/K quantile of each 2D tensor in a batch

I have a torch.Tensor of shape (2, 2, 2) (it can be bigger) whose values are normalized to the range [0, 1].

Now I am given a positive integer K, and I need to create a mask where, for each 2D tensor in the batch, a value is 1 if it is larger than 1/K of all the values in that 2D tensor (i.e. larger than its 1/K quantile) and 0 otherwise. The returned mask also has shape (2, 2, 2).

For example, if I have a batch like this:

tensor([[[1., 3.],
         [2., 4.]],
        [[5., 7.],
         [9., 8.]]])

and K = 2, then the mask must be 1 for the values that are greater than 50% of all the values in their own 2D tensor.

In this example, the 0.5 quantiles of the two 2D tensors are 2.5 and 7.5, so this is the desired output:

tensor([[[0, 1],
         [0, 1]],
        [[0, 0],
         [1, 1]]])
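Why 2.5 and 7.5? By default torch.quantile uses linear interpolation: it sorts the values and, when the q-quantile falls at a fractional position q*(n-1) between two sorted values, it interpolates between them. With four values per 2D tensor and q = 0.5, that is the midpoint of the 2nd and 3rd smallest values. A small sketch that reproduces this by hand for the example batch (the manual loop is only for illustration, not part of any API):

```python
import torch

x = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
K = 2
q = 1 / K

manual = []
for a in x:
    s = a.flatten().sort().values      # sorted values, e.g. tensor([1., 2., 3., 4.])
    pos = q * (s.numel() - 1)          # fractional index into s: 0.5 * 3 = 1.5
    lo = int(pos)
    hi = min(lo + 1, s.numel() - 1)
    # linear interpolation between the two neighbouring sorted values
    manual.append((s[lo] + (pos - lo) * (s[hi] - s[lo])).item())

print(manual)                                    # [2.5, 7.5]
print([torch.quantile(a, q).item() for a in x])  # [2.5, 7.5]
```

Note that the mask uses a strict "greater than", so a value exactly equal to its quantile gets 0.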

I tried:

import torch

K = 2
a = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
# One 1/K quantile per 2D tensor -> shape (2,)
quantile = torch.tensor([torch.quantile(x, 1/K) for x in a])
torch.where(a > quantile, 1, 0)

But this is the result:

tensor([[[0, 0],
         [0, 0]],
        [[1, 0],
         [1, 1]]])
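This result comes from how broadcasting aligns shapes: it matches dimensions from the right, so a quantile tensor of shape (2,) is compared against the last dimension of the (2, 2, 2) batch. Column 0 of every 2D tensor is compared with 2.5 and column 1 with 7.5, regardless of which batch element it belongs to. A small sketch that makes this explicit, using the quantiles hard-coded from the example:

```python
import torch

x = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
quantile = torch.tensor([2.5, 7.5])  # shape (2,)

# Broadcasts against the LAST dimension of x, not the batch dimension
wrong = torch.where(x > quantile, 1, 0)

# The same comparison written out: column 0 vs 2.5, column 1 vs 7.5
explicit = torch.stack([x[..., 0] > 2.5, x[..., 1] > 7.5], dim=-1).long()
print(torch.equal(wrong, explicit))  # True
```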

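The problem is broadcasting: quantile has shape (2,), so a > quantile compares it against the last dimension of a (column by column) instead of against each 2D tensor. Reshape the quantiles to (2, 1, 1) so that each one is broadcast over its own 2D tensor: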

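import torch

x = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
K = 2

# One 1/K quantile per 2D tensor -> shape (2,)
quantile = torch.tensor([torch.quantile(a, 1/K) for a in x])

# quantile[:, None, None] has shape (2, 1, 1), so each quantile is
# broadcast over its own 2x2 tensor
res = torch.where(x > quantile[:, None, None], 1, 0)
print(res)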
> tensor([[[0, 1],
           [0, 1]],

          [[0, 0],
           [1, 1]]])
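As an alternative to the list comprehension, the quantiles can be computed in a single call by flattening each batch element into a row and passing dim=1 with keepdim=True. That gives shape (B, 1), which broadcasts row-wise without the [:, None, None] indexing and works for any number of trailing dimensions. A minimal sketch; the helper name quantile_mask is hypothetical, and it assumes a floating-point input, since torch.quantile requires a float or double tensor:

```python
import torch

def quantile_mask(x: torch.Tensor, K: int) -> torch.Tensor:
    """Hypothetical helper: 1 where a value is greater than the 1/K quantile
    of its own batch element, 0 elsewhere. x has shape (B, ...) and must be
    a float or double tensor."""
    flat = x.flatten(start_dim=1)                         # (B, N): one row per batch element
    q = torch.quantile(flat, 1 / K, dim=1, keepdim=True)  # (B, 1): one quantile per row
    mask = (flat > q).long()                              # (B, 1) broadcasts along each row
    return mask.reshape(x.shape)                          # back to the input shape

x = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
print(quantile_mask(x, 2))
```

For the example batch this should print the same mask as the broadcasting answer above.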