For example, I have a tensor A:
A = torch.tensor(
    [[1, 2, 3, 4, 5, 6, 7, 8],
     [5, 6, 7, 8, 9, 10, 11, 12],
     [3, 5, 6, 7, 8, 8, 10, 11]])
Given an index list idx_list = torch.tensor([3,1,5]), I want to build a tensor B from A such that
B[i][j] = A[i][idx_list[i]] if j < idx_list[i], and B[i][j] = A[i][j] otherwise.
In other words, in row i of A, every entry whose column index j is less than idx_list[i] should be replaced with the value A[i][idx_list[i]].
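As a sketch of the rule stated above, here is a literal per-row transcription. The name func_loop is hypothetical, and it assumes every idx_list[i] is a valid column index of A. It is slow for many rows but handy as a reference to check faster versions against.

```python
import torch

def func_loop(A: torch.Tensor, idx_list: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference implementation: work on a copy so A is unchanged
    B = A.clone()
    for i, k in enumerate(idx_list.tolist()):
        # Overwrite columns 0..k-1 of row i with the pivot value A[i, k]
        B[i, :k] = A[i, k]
    return B
```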
Given idx_list = torch.tensor([3,1,5]), the resulting B should be:
B = [[4,4,4,4,5,6,7,8],
[6,6,7,8,9,10,11,12],
[8,8,8,8,8,8,10,11]]
Could anyone suggest how to implement a function func(A, idx_list) that returns B?
Thanks very much
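One possible vectorized sketch of func uses a broadcast column mask with torch.gather and torch.where. It assumes A and idx_list live on the same device, that A is 2-D, and that each idx_list[i] satisfies 0 <= idx_list[i] < A.size(1). It is not the only way to do this.

```python
import torch

def func(A: torch.Tensor, idx_list: torch.Tensor) -> torch.Tensor:
    idx = idx_list.long().to(A.device)
    # Column indices 0..n_cols-1 as a row vector, shape (1, n_cols)
    cols = torch.arange(A.size(1), device=A.device).unsqueeze(0)
    # Pivot value A[i, idx[i]] for every row, shape (n_rows, 1)
    pivot = A.gather(1, idx.unsqueeze(1))
    # True where column j < idx[i]; broadcasts to shape (n_rows, n_cols)
    mask = cols < idx.unsqueeze(1)
    # Use the pivot where the mask is set, otherwise keep the original entry
    return torch.where(mask, pivot, A)
```

Because B[i][j] equals A[i][max(j, idx_list[i])], a single gather also works: A.gather(1, torch.maximum(cols, idx.unsqueeze(1))), with cols and idx defined as in the sketch.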