I have a `torch.Tensor` of shape `(2, 2, 2)` (it can be bigger), where the values are normalized to the range `[0, 1]`.
Now I am given a positive integer `K`, and I need to create a mask in which, for each 2D tensor in the batch, an element is 1 if it is greater than the `1/K` quantile of all the values in that 2D tensor, and 0 otherwise. The returned mask also has shape `(2, 2, 2)`.
For example, if I have a batch like this:
```
tensor([[[1., 3.],
         [2., 4.]],

        [[5., 7.],
         [9., 8.]]])
```
and `K = 2`, then the mask must be 1 wherever a value is greater than the 0.5 quantile (the median) of its own 2D tensor. In the example, those quantiles are `2.5` and `7.5` respectively, so this is the desired output:
```
tensor([[[0, 1],
         [0, 1]],

        [[0, 0],
         [1, 1]]])
```
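The per-slice rule described above can be sketched in vectorized form. This is a minimal sketch, assuming the input has a leading batch dimension and a floating-point dtype (`torch.quantile` rejects integer tensors); the helper name `quantile_mask` is hypothetical. Each 2D slice is flattened so that one quantile is computed over all of its values, and `keepdim=True` keeps the thresholds in a shape that broadcasts against the flattened rows.

```python
import torch


def quantile_mask(x: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: 1 where an element exceeds its slice's 1/k quantile.

    Assumes x is a float tensor whose first dimension is the batch.
    """
    # Flatten every slice after the batch dim: shape (B, N).
    flat = x.reshape(x.shape[0], -1)
    # One threshold per slice; keepdim=True gives shape (B, 1).
    q = torch.quantile(flat, 1.0 / k, dim=1, keepdim=True)
    # (B, N) > (B, 1) compares each row only with its own threshold.
    return (flat > q).to(torch.long).reshape(x.shape)


x = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])
print(quantile_mask(x, 2).tolist())  # [[[0, 1], [0, 1]], [[0, 0], [1, 1]]]
```

Comparing on the flattened view and reshaping back means the same code should also work for slices with more than two trailing dimensions.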
I tried:
```python
import torch

K = 2
a = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])

# One 1/K quantile per 2D slice of the batch
quantile = torch.stack([torch.quantile(x, 1 / K) for x in a])
torch.where(a > quantile, 1, 0)
```
But this is the result:
```
tensor([[[0, 0],
         [0, 0]],

        [[1, 0],
         [1, 1]]])
```
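The result above is consistent with PyTorch's broadcasting rules: `quantile` has shape `(2,)`, and broadcasting aligns trailing dimensions, so the two thresholds are matched against the last axis of `a` (column 0 against `2.5`, column 1 against `7.5`) in both slices instead of one threshold per slice. A minimal sketch of the difference, reusing the loop from the attempt and only reshaping the thresholds to `(2, 1, 1)` so they line up with the batch axis:

```python
import torch

K = 2
a = torch.tensor([[[1., 3.],
                   [2., 4.]],
                  [[5., 7.],
                   [9., 8.]]])

quantile = torch.stack([torch.quantile(x, 1 / K) for x in a])
print(quantile)  # tensor([2.5000, 7.5000])

# Shape (2,) broadcasts along the LAST axis: columns, not batch entries.
wrong = torch.where(a > quantile, 1, 0)
print(wrong.tolist())  # [[[0, 0], [0, 0]], [[1, 0], [1, 1]]]

# Shape (2, 1, 1) broadcasts along the FIRST axis: one threshold per slice.
right = torch.where(a > quantile.view(-1, 1, 1), 1, 0)
print(right.tolist())  # [[[0, 1], [0, 1]], [[0, 0], [1, 1]]]
```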