I have a 2D tensor and I have the following problem:
```python
a = tensor([[1296, 1295, 1292, 4, 1311, 4, 1293, 2],
            [1297, 1295, 1292, 1404, 1294, 4, 1293, 2]])
```
I need to replace every value greater than 1292 with its rank among those values in the same row, counting from 1 in ascending order of value. Values less than or equal to 1292 should stay unchanged. What I want is:
```python
tensor([[3, 2, 1292, 4, 4, 4, 1, 2],
        [4, 3, 1292, 5, 2, 4, 1, 2]])
```
How can I do this?
Kushaj (Kushajveer Singh), August 1, 2020, 10:05pm
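```python
import torch

def func(x, threshold=1292):
    # argsort of the argsort gives each element's 0-based position
    # in ascending sorted order within its row.
    indices = torch.argsort(x, dim=1)
    s_indices = torch.argsort(indices, dim=1)
    # Values at or below the threshold are kept as-is.
    mask = x <= threshold
    # All kept values sort before the values being replaced, so subtracting
    # their count turns the position into a 1-based rank among values > threshold.
    num_kept = mask.sum(dim=1, keepdim=True)
    s_indices -= num_kept - 1
    return torch.where(mask, x, s_indices)

x = torch.tensor([[1296, 1295, 1292, 4, 1311, 4, 1293, 2],
                  [1297, 1295, 1292, 1401, 1294, 4, 1293, 2]], dtype=torch.int64)
func(x)
# tensor([[   3,    2, 1292,    4,    4,    4,    1,    2],
#         [   4,    3, 1292,    5,    2,    4,    1,    2]])
```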
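The key step in `func` is applying `argsort` twice: the first call gives the indices that would sort the row, and the second turns those into each element's position in sorted order. A minimal illustration on a made-up row and a made-up threshold of 15 (both chosen only for this example, with no ties so the order is deterministic):

```python
import torch

x = torch.tensor([[30, 10, 20]])
order = torch.argsort(x, dim=1)     # indices that would sort the row: [[1, 2, 0]]
rank = torch.argsort(order, dim=1)  # 0-based sorted position of each element: [[2, 0, 1]]

threshold = 15  # illustrative value only
num_kept = (x <= threshold).sum(dim=1, keepdim=True)  # [[1]]: only 10 is kept
# Shift so the smallest value above the threshold gets rank 1.
result = torch.where(x <= threshold, x, rank - num_kept + 1)
print(result)  # tensor([[ 2, 10,  1]])
```

Here 20 becomes 1 and 30 becomes 2, while 10 is left alone, which is the same shifting that `func` does with `num_kept`.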
How can I do the same thing when values are repeated within a row? Equal values should get the same rank:
```python
a = tensor([[1296, 1295, 1292, 4, 1311, 4, 1293, 2],
            [1297, 1295, 1292, 1404, 1293, 4, 1293, 2]])
```
`a` should become:
```python
a = tensor([[3, 2, 1292, 4, 4, 4, 1, 2],
            [3, 2, 1292, 4, 1, 4, 1, 2]])
```
Kushaj (Kushajveer Singh), August 2, 2020, 2:02pm
I don’t think the code above can handle that without adding more complexity. Alternatively, you can write a for loop that does the same thing by brute force.
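A possible sketch of both options, assuming the goal is a "dense" rank (equal values share a rank and the next distinct value gets the next integer). `dense_rank_loop` is the row-by-row approach, using `torch.unique` and `torch.searchsorted`; `dense_rank_vectorized` is one loop-free alternative that sorts each row once and counts distinct values with `cumsum`. Both function names and the `threshold` parameter are illustrative, not from the thread.

```python
import torch

def dense_rank_loop(x: torch.Tensor, threshold: int = 1292) -> torch.Tensor:
    """Brute-force version: handle one row at a time."""
    out = x.clone()
    for i in range(x.size(0)):
        above = x[i] > threshold
        # Distinct values above the threshold, sorted ascending.
        uniq = torch.unique(x[i][above])
        # Index of each value in `uniq`, shifted to start at 1;
        # repeated values map to the same index, so they share a rank.
        out[i, above] = torch.searchsorted(uniq, x[i][above]) + 1
    return out

def dense_rank_vectorized(x: torch.Tensor, threshold: int = 1292) -> torch.Tensor:
    """Loop-free version based on a single sort per row."""
    sorted_vals, order = torch.sort(x, dim=1)
    # Previous value in sorted order; the first column is compared against
    # `threshold`, so a leading value above it always counts as new.
    prev = torch.cat([torch.full_like(sorted_vals[:, :1], threshold),
                      sorted_vals[:, :-1]], dim=1)
    # A new rank starts at each distinct value above the threshold.
    is_new = (sorted_vals > threshold) & (sorted_vals != prev)
    dense_sorted = torch.cumsum(is_new.long(), dim=1)
    # Undo the sort: ranks[i, order[i, j]] = dense_sorted[i, j].
    ranks = torch.empty_like(dense_sorted).scatter_(1, order, dense_sorted)
    return torch.where(x > threshold, ranks, x)

a = torch.tensor([[1296, 1295, 1292, 4, 1311, 4, 1293, 2],
                  [1297, 1295, 1292, 1404, 1293, 4, 1293, 2]])
print(dense_rank_loop(a))
print(dense_rank_vectorized(a))
# Both print:
# tensor([[   3,    2, 1292,    4,    4,    4,    1,    2],
#         [   3,    2, 1292,    4,    1,    4,    1,    2]])
```

The loop version is easier to read; the vectorized one avoids Python-level iteration over rows, which may matter for large batches.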