Hey everyone!
I have a question about selecting/changing values at specific indices.
Here is an example:
import torch
"""
The original tensor is like this:
t = torch.tensor(
[[99, 2, 1, 3, 4],
[1, 2, 3, 99, 4],
[4, 1, 99, 2, 3]]
)
I want to change t's values with this rule:
for each row, set the elements before the 99 to 1 and the elements after it to 0, then drop the 99 itself,
so my target is:
tensor(
[[0, 0, 0, 0],
[1, 1, 1, 0],
[1, 1, 0, 0]]
)
Following is the for-loop way. My question is: can I do this more efficiently?
Thanks!
"""
# create a tensor of shape [3, 5]
t = torch.tensor(
[[99, 2, 1, 3, 4],
[1, 2, 3, 99, 4],
[4, 1, 99, 2, 3]]
)
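# get the column index of the 99 in each row (assumes exactly one 99 per row)
position_of_99 = torch.nonzero(t == 99)[:, 1]
print(position_of_99)
# for-loop: each row is a view into t, so these assignments modify t in place
for row, pos in zip(t, position_of_99):
    # elements before the 99 become 1
    for i in range(pos):
        row[i] = 1
    # elements after the 99 become 0; the 99 itself is kept for now
    for i in range(pos + 1, len(row)):
        row[i] = 0
print(t)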
# drop 99
t = t[t != 99].view(-1, 4)
print(t)
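Since the question is how to avoid the loop, here is a minimal vectorized sketch. It reuses the `torch.nonzero` positions from the script above and compares them against a column index with broadcasting, which builds the target directly instead of modifying `t` and then dropping the 99. It assumes exactly one 99 per row; `pos`, `cols`, and `result` are just illustrative names.

```python
import torch

t = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

# Column index of the 99 in each row. Assumes exactly one 99 per row, so
# nonzero returns one (row, col) pair per row, in row order.
pos = torch.nonzero(t == 99)[:, 1]  # tensor([0, 3, 2])

# After dropping the 99, each row has t.size(1) - 1 entries. Entry j should be
# 1 exactly when j < pos, i.e. when it came from a column before the 99.
cols = torch.arange(t.size(1) - 1, device=t.device)  # tensor([0, 1, 2, 3])

# Broadcast [1, 4] against [3, 1] to get a [3, 4] boolean mask, then cast to int64.
result = (cols.unsqueeze(0) < pos.unsqueeze(1)).long()
print(result)
```

Everything here is a single batched tensor operation, so it should also run on the GPU when `t` lives there (hence `device=t.device` on `arange`).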
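A loop-free sketch that skips computing the positions: a cumulative sum of the `t == 99` mask along each row stays 0 until the 99 is reached, which marks exactly the elements that should become 1. The 99s are then dropped with the same boolean-mask indexing as the `# drop 99` step in the script. This still assumes one 99 per row, since the final `view` needs every row to keep the same number of elements; `is_99`, `before_99`, and `result` are illustrative names.

```python
import torch

t = torch.tensor(
    [[99, 2, 1, 3, 4],
     [1, 2, 3, 99, 4],
     [4, 1, 99, 2, 3]]
)

is_99 = t == 99

# The running count of 99s along each row is 0 only for elements that come
# before the first 99, so this is True exactly for "before the 99".
before_99 = is_99.long().cumsum(dim=1) == 0

# Boolean-mask indexing drops the 99 positions and returns a flat tensor in
# row-major order; reshape to one row per original row (assumes one 99 per row).
result = before_99[~is_99].view(t.size(0), -1).long()
print(result)
```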