I have a 2D torch tensor X and a 1D index tensor ind. For each row i of X, I want to set every value before the corresponding index (columns 0 to ind[i] - 1) to 2, like this:
X = [[1, 0, 4, 5, 6],
     [3, 6, 7, 10, 13],
     [1, 4, 2, 8, 21]]
ind = [2, 3, 1]
result = [[2, 2, 4, 5, 6],
          [2, 2, 2, 10, 13],
          [2, 4, 2, 8, 21]]
Is there an easy way to do so?
Thanks
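One possible vectorized approach, sketched under the assumption that the cutoff in ind is exclusive (which is what the expected result implies): compare a row of column positions against each row's cutoff to build a boolean mask, then fill the masked positions. It uses only torch.arange, broadcasting, and Tensor.masked_fill, so no Python loop over rows is needed.

```python
import torch

X = torch.tensor([[1, 0, 4, 5, 6],
                  [3, 6, 7, 10, 13],
                  [1, 4, 2, 8, 21]])
ind = torch.tensor([2, 3, 1])

# Column positions 0 .. n_cols-1, created on the same device as X.
cols = torch.arange(X.size(1), device=X.device)

# Broadcast (1, n_cols) against (n_rows, 1): True where col < ind[row].
mask = cols.unsqueeze(0) < ind.unsqueeze(1)

# Out-of-place fill; use X.masked_fill_(mask, 2) to modify X in place instead.
result = X.masked_fill(mask, 2)
print(result)
```

Because the mask comes from a comparison rather than indexing, an ind value of 0 leaves its row untouched, and a value at or beyond the row length fills the whole row without raising an error.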