Change tensor values by index greater (or less) than some value

Hey everyone!

I have a question about selecting/changing values at specific indices.
Here is an example:

import torch

"""
The original tensor is like this:

t = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

I want to change t's value with rule:
for each row, set elements that are before 99 to 1; 0 otherwise,

so my target is:

tensor(
    [[0, 0, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 0, 0]]
)

Below is the for-loop way. My question is: can I do this more efficiently?

Thanks!

"""

# create a tensor of shape [3, 5]
t = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

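# get the column of the 99 in each row (assumes exactly one 99 per row)
pos_99 = torch.nonzero(t == 99)[:, 1]

print(pos_99)  # tensor([0, 3, 2])

# for-loop: each row is a view into t, so writing to it updates t in place
for row, p in zip(t, pos_99):
    # elements before the 99 become 1
    for i in range(p):
        row[i] = 1

    # elements after the 99 become 0 (the 99 itself is left untouched)
    for i in range(p + 1, len(row)):
        row[i] = 0

print(t)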

# drop the 99s; boolean indexing returns a flat tensor, so reshape to [3, 4]
t = t[t != 99].view(-1, 4)
print(t)
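Because `t[t != 99]` returns a flat 1-D tensor, the `.view(-1, 4)` step silently relies on every row containing exactly one 99. A minimal sketch that makes this assumption explicit; the count check is an addition for illustration, not part of the original loop:

```python
import torch

# t as it looks after the for-loop: 99s untouched, other entries 1 or 0
t = torch.tensor(
    [[99, 0, 0, 0, 0],
     [1, 1, 1, 99, 0],
     [1, 1, 99, 0, 0]]
)

# Boolean indexing always returns a flat 1-D tensor, not a 2-D one
kept = t[t != 99]
print(kept.shape)  # torch.Size([12])

# Reshaping back to 2-D is only valid if every row lost exactly one element
counts = (t == 99).sum(dim=1)
assert bool((counts == 1).all()), "each row must contain exactly one 99"
res = kept.view(t.size(0), -1)
print(res)
```

If a row had zero or two 99s, the flat tensor would no longer split evenly into rows of equal length, so the check fails early instead of producing a silently misaligned result.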

The desired target doesn’t match the description:

for each row, set elements that are before 99 to 1; 0 otherwise,

since the 99 in the first row is set to 0, while the others are not.
Assuming you would like to set the position of the 99 itself to 1, this code should work and avoids the loop:

t2 = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

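# running count of 99s per row: 0 before the 99, 1 from the 99 onwards
idx = torch.cumsum((t2 == 99), 1)
# reset the 99's own position to 0 so it counts as "before"
idx[t2==99] = 0
# count 0 -> 1, anything else -> 0
res = (~idx.bool()).long()
print(res)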
> tensor([[1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0],
          [1, 1, 1, 0, 0]])
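To see why the `cumsum` trick works, it can help to print the intermediate tensors. A small trace of the same three steps on the example input:

```python
import torch

t2 = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

# Step 1: running count of 99s seen so far in each row
idx = torch.cumsum(t2 == 99, 1)
print(idx)
# tensor([[1, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]])

# Step 2: zero the count at the 99 itself so that position is treated as "before"
idx[t2 == 99] = 0
print(idx)
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 1, 1]])

# Step 3: positions with count 0 become 1, everything else becomes 0
res = (~idx.bool()).long()
print(res)
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 0, 0]])
```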

Depending on the shape of your tensor, it might be faster or not, so I would recommend profiling both approaches.
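One way to profile both is `torch.utils.benchmark.Timer`. The sketch below is only an illustration: the function names, the random input generator `make_input`, and the sizes (1000 x 50, 10 runs) are hypothetical choices, and the loop version slices each row instead of assigning element by element. Both versions assume exactly one 99 per row.

```python
import torch
from torch.utils import benchmark


def make_input(rows, cols):
    # Hypothetical test data: random values 0..9 plus exactly one 99 per row
    t = torch.randint(0, 10, (rows, cols))
    pos = torch.randint(0, cols, (rows,))
    t[torch.arange(rows), pos] = 99
    return t


def loop_version(t):
    t = t.clone()  # work on a copy so repeated runs see the same input
    pos = torch.nonzero(t == 99)[:, 1]
    for row, p in zip(t, pos.tolist()):
        row[:p] = 1      # before the 99
        row[p + 1:] = 0  # after the 99
    return t[t != 99].view(t.size(0), -1)


def cumsum_version(t):
    idx = torch.cumsum(t == 99, 1)
    idx[t == 99] = 2  # mark the 99 positions (cumsum is at most 1 elsewhere)
    idx = idx[idx != 2].view(t.size(0), -1)
    return (~idx.bool()).long()


x = make_input(1000, 50)
assert torch.equal(loop_version(x), cumsum_version(x))

for fn in (loop_version, cumsum_version):
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(fn.__name__, timer.timeit(10))
```

Checking that both versions agree before timing them avoids comparing the speed of two functions that compute different things.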

"since the 99 in the first row is set to 0, while the others are not."

Sorry for the unclear description. Actually, I just want to DROP all the 99s.

I slightly modified your solution, and it works for me too!

t2 = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

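idx = torch.cumsum((t2 == 99), 1)
# cumsum is at most 1 here (one 99 per row), so 2 can mark the 99 positions
idx[t2==99] = 2
# drop the marked positions and restore the [3, 4] shape
idx = idx[idx != 2].view(3, -1)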

res = (~idx.bool()).long()
print(res)
> tensor([[0, 0, 0, 0],
          [1, 1, 1, 0],
          [1, 1, 0, 0]])
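For comparison, a sketch of another loop-free variant (not from the thread) that builds the 4-column target directly, without the cumsum or the marker value 2. It again assumes exactly one 99 per row: after dropping the 99, column `j` of a row was "before the 99" exactly when `j` is smaller than the 99's original column.

```python
import torch

t2 = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

# Column of the (single) 99 in each row, as in the original question
pos = torch.nonzero(t2 == 99)[:, 1]  # tensor([0, 3, 2])

# Column indices of the result, which has one column fewer than t2
cols = torch.arange(t2.size(1) - 1)  # tensor([0, 1, 2, 3])

# Broadcast [1, 4] against [3, 1]: 1 where the column lies before the 99
res = (cols.unsqueeze(0) < pos.unsqueeze(1)).long()
print(res)
# tensor([[0, 0, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 0, 0]])
```

Since the result only depends on the 99 positions, it never materializes the intermediate `[3, 5]` tensor, but as with the other approaches it is worth profiling on realistic shapes.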

Thanks for your help!