How do I left-fill each row of a PyTorch tensor up to a per-row index?

For example, I have a tensor A:

A = torch.tensor(

Given an index list idx_list = torch.tensor([3,1,5]), I want to get a tensor B from A such that

B[i][j] = A[i][idx_list[i]]   if j < idx_list[i]
B[i][j] = A[i][j]             otherwise

In other words, in row i of A, every entry in a column j < idx_list[i] should be replaced with A[i][idx_list[i]], and the remaining entries are left unchanged.
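To pin the rule down, here is a slow plain-Python reference sketch on nested lists (the name left_fill_ref is hypothetical; it is only a way to check a vectorized PyTorch version, not a solution in itself):

```python
def left_fill_ref(A, idx_list):
    """Reference (slow) version of the rule on nested Python lists.

    B[i][j] = A[i][idx_list[i]] if j < idx_list[i], else A[i][j].
    """
    B = []
    for row, k in zip(A, idx_list):
        fill = row[k]  # value at column idx_list[i] of this row
        B.append([fill if j < k else x for j, x in enumerate(row)])
    return B


# First row of A from the question, with idx_list[0] = 3
print(left_fill_ref([[1, 2, 3, 4, 5, 6, 7, 8]], [3]))
# [[4, 4, 4, 4, 5, 6, 7, 8]]
```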

So, with idx_list = torch.tensor([3,1,5]), the resulting B should be:

B = [[4,4,4,4,5,6,7,8],

Could anyone suggest how to implement a function func(A, idx_list) that returns B?

Thanks very much

This code should work:

import torch

A = torch.tensor([[1,2,3,4,5,6,7,8],

idx_list = torch.tensor([3,1,5])

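values = A[torch.arange(A.size(0)), idx_list]  # A[i][idx_list[i]] for each row i
mask = torch.arange(A.size(1), device=A.device)[None, :] < idx_list[:, None]  # column j < idx_list[i]
B = torch.where(mask, values[:, None], A)  # fill value where mask is True, else keep A[i][j]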
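As a sketch of the func(A, idx_list) the question asks for, the same column-index mask can be wrapped in a function. The name left_fill, the use of torch.gather (an alternative to the advanced indexing A[torch.arange(A.size(0)), idx_list]), and the two lower example rows are my own assumptions, not from the original posts; those rows are deliberately unsorted to show that the mask depends only on column positions:

```python
import torch


def left_fill(A: torch.Tensor, idx_list: torch.Tensor) -> torch.Tensor:
    """Return B with B[i, j] = A[i, idx_list[i]] for j < idx_list[i], else A[i, j].

    Assumes A is 2-D and idx_list is a 1-D int64 tensor holding one in-range
    column index per row of A. A itself is not modified.
    """
    idx_list = idx_list.to(A.device)
    # Fill value per row, shape (n_rows, 1)
    values = A.gather(1, idx_list[:, None])
    # True where column j < idx_list[i]; broadcasts to A's shape
    cols = torch.arange(A.size(1), device=A.device)
    mask = cols[None, :] < idx_list[:, None]
    return torch.where(mask, values, A)


# Hypothetical example: only the first row comes from the question
A = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8],
                  [8, 7, 6, 5, 4, 3, 2, 1],
                  [3, 1, 4, 1, 5, 9, 2, 6]])
B = left_fill(A, torch.tensor([3, 1, 5]))
# B rows: [4, 4, 4, 4, 5, 6, 7, 8]
#         [7, 7, 6, 5, 4, 3, 2, 1]
#         [9, 9, 9, 9, 9, 9, 2, 6]
```

Comparing values (A < values[:, None]) instead of column indices only gives the same result when each row is increasing; on the descending second row it would produce [8, 7, 7, 7, 7, 7, 7, 7].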

Thank you very much! It works.