Assign values to a PyTorch tensor

I have a 2D tensor and I have the following problem:

a=tensor([[1296, 1295, 1292,    4, 1311,    4, 1293,    2],
        [1297, 1295, 1292, 1404, 1294,    4, 1293,    2]])

I need to replace every value greater than 1292 with its rank among those values in its row, counting up from 1 in ascending order. Values of 1292 or less should stay unchanged. What I want is:

tensor([[3, 2, 1292,    4, 4,    4, 1,    2],
        [4, 3, 1292, 5, 2,    4, 1,    2]])

How can I do this?

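import torch

def func(x):
    # Rank of each element within its row: sorting the sort indices
    # (argsort of argsort) gives 0-based ascending ranks.
    _, indices = torch.sort(x, dim=1)
    _, ranks = torch.sort(indices, dim=1)

    # Values <= 1292 are kept; they occupy the lowest ranks in each row.
    mask = x <= 1292
    num_kept = mask.sum(dim=1, keepdim=True)

    # Shift so the smallest value above 1292 gets rank 1.
    ranks = ranks - (num_kept - 1)

    return torch.where(mask, x, ranks)

x = torch.tensor([[1296, 1295, 1292, 4, 1311, 4, 1293, 2],
                  [1297, 1295, 1292, 1404, 1294, 4, 1293, 2]], dtype=torch.int64)
func(x)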

# tensor([[   3,    2, 1292,    4,    4,    4,    1,    2],
#         [   4,    3, 1292,    5,    2,    4,    1,    2]])

How can I do the same thing when values are repeated along a dimension?

a=tensor([[1296, 1295, 1292,    4, 1311,    4, 1293,    2],
        [1297, 1295, 1292, 1404, 1293,    4, 1293,    2]])

a should become:

a = tensor([[3, 2, 1292,    4, 4,    4, 1,    2],
        [3, 2, 1292, 4, 1,    4, 1,    2]])

I don’t think that can be done with the above code without introducing more complexity: it gives every element in a row a distinct rank, so repeated values cannot share one. You can also write a for loop that does the same in a brute-force manner.
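
For what it's worth, here is a minimal sketch of what such a brute-force loop might look like. It is not the code from the answer above: the function name `rank_above_threshold` and its `threshold` parameter are made up for illustration, and it swaps in `torch.unique` with `return_inverse=True` so that equal values in a row share the same rank.

```python
import torch

def rank_above_threshold(x, threshold=1292):
    # Hypothetical helper (name and signature are assumptions for this sketch).
    # Replaces each value above `threshold` with its dense rank (1 = smallest)
    # among the values above the threshold in the same row.
    out = x.clone()
    for i in range(x.size(0)):
        mask = x[i] > threshold
        if mask.any():
            # `inverse` maps each selected element to the index of its value
            # in the sorted list of unique values, so duplicates share an index.
            _, inverse = torch.unique(x[i][mask], sorted=True, return_inverse=True)
            out[i, mask] = inverse + 1
    return out

a = torch.tensor([[1296, 1295, 1292, 4, 1311, 4, 1293, 2],
                  [1297, 1295, 1292, 1404, 1293, 4, 1293, 2]])
print(rank_above_threshold(a))
# tensor([[   3,    2, 1292,    4,    4,    4,    1,    2],
#         [   3,    2, 1292,    4,    1,    4,    1,    2]])
```

Both occurrences of 1293 in the second row get rank 1, and 1404 gets rank 4 because it is the fourth distinct value above the threshold. The Python loop runs once per row, so this is probably only reasonable when the number of rows is small.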