Set value of torch tensor up to some index

I have a 2D torch tensor (X) and a 1D index torch tensor (ind). For each row of X, I want to set the values before the corresponding index in ind to 2, like the following:

X = [[1, 0, 4, 5, 6],
     [3, 6, 7, 10, 13],
     [1, 4, 2, 8, 21]]
ind = [2, 3, 1]

result = [[2, 2, 4, 5, 6],
          [2, 2, 2, 10, 13],
          [2, 4, 2, 8, 21]]
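To pin down the intended semantics (the index is exclusive, so row `i` gets its first `ind[i]` entries replaced), here is a plain per-row loop. It is slow for large tensors but handy as a reference to check a vectorized version against. The function name `fill_prefix_loop` is hypothetical, chosen for illustration.

```python
import torch

def fill_prefix_loop(X: torch.Tensor, ind: torch.Tensor, value) -> torch.Tensor:
    """Hypothetical helper: return a copy of X where row i has its first ind[i] entries set to value."""
    out = X.clone()  # leave the input tensor untouched
    for i, k in enumerate(ind.tolist()):
        out[i, :k] = value  # slice end is exclusive, so columns 0..k-1 are filled
    return out

X = torch.tensor([[1, 0, 4, 5, 6],
                  [3, 6, 7, 10, 13],
                  [1, 4, 2, 8, 21]])
ind = torch.tensor([2, 3, 1])
print(fill_prefix_loop(X, ind, 2))
```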

Is there an easy way to do so?
Thanks

Hi,

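You can do this by creating a binary mask.
First, create a tensor of zeros with the same shape as X and put a 1 in each row at the column given by ind. Then take the cumulative sum along each row, which turns every entry from that column onward into 1, and subtract the result from 1 so that only the columns before the index remain 1. Finally, use this mask to index X: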

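import torch

X = torch.tensor(
    [[1, 0, 4, 5, 6],
     [3, 6, 7, 10, 13],
     [1, 4, 2, 8, 21]])

ind = torch.tensor([2, 3, 1])

# Start from zeros and put a 1 at column ind[i] of each row i
mask = torch.zeros_like(X)
mask[(torch.arange(X.shape[0]), ind)] = 1
# cumsum turns each row into 0...0 1...1 (ones from column ind[i] onward);
# subtracting from 1 flips it, so the ones now cover columns 0..ind[i]-1
mask = 1 - mask.cumsum(dim=-1)
mask
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 0, 0, 0, 0]])

X[mask.bool()] = 2  # or any other desired value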

Best,
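Another way to build the same mask, sketched here as an alternative rather than as the method from the reply, is a single broadcasting comparison: entry `(i, j)` is True exactly when `j < ind[i]`. Unlike the scatter-plus-`cumsum` version, this also handles rows where `ind[i]` equals the number of columns (fill the whole row); there, `mask[(torch.arange(X.shape[0]), ind)] = 1` raises an out-of-bounds `IndexError`. The helper name `fill_prefix` is hypothetical, and the sketch assumes `ind` is an integer tensor on the same device as `X`.

```python
import torch

def fill_prefix(X: torch.Tensor, ind: torch.Tensor, value) -> torch.Tensor:
    """Hypothetical helper: return a copy of X with X[i, :ind[i]] set to value for every row i."""
    cols = torch.arange(X.shape[1], device=X.device)  # shape (num_cols,)
    mask = cols.unsqueeze(0) < ind.unsqueeze(1)       # broadcasts to (num_rows, num_cols), bool
    # masked_fill returns a new tensor; X.masked_fill_(mask, value) would modify X in place
    return X.masked_fill(mask, value)

X = torch.tensor([[1, 0, 4, 5, 6],
                  [3, 6, 7, 10, 13],
                  [1, 4, 2, 8, 21]])
ind = torch.tensor([2, 3, 1])
print(fill_prefix(X, ind, 2))

# ind[i] == X.shape[1] fills the whole row; ind[i] == 0 leaves the row untouched
print(fill_prefix(X, torch.tensor([5, 0, 1]), 2))
```

The boolean mask here plays the same role as `mask.bool()` in the reply, so `X[mask] = 2` also works if in-place assignment is preferred.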


Great! Thanks @Nikronic 🙂