How could I achieve the following without a for loop?
Should I use the scatter function?
Thanks!
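```python
import torch

position_ids = torch.arange(500).reshape(10, 50)
prompt_pos = torch.arange(50).reshape(10, 5)
for index in range(10):
    # Overwrite row `index` from column prompt_pos[index][0] onward with that value.
    position_ids[index][prompt_pos[index][0]:] = prompt_pos[index][0]
```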
I don't see a way to do this with scatter(). You can, however, compute
an appropriate boolean mask and then use it to index into position_ids
(or, equivalently, use torch.masked_select()).
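Here is a minimal sketch of the mask approach. It assumes that, as in the loop, only the first column of prompt_pos (the per-row start offset) matters. The mask compares a row of column indices against a column of start offsets, and broadcasting turns that into a (10, 50) boolean tensor. The values on the right-hand side are picked out with torch.masked_select() so that their count matches the number of True entries. The torch.where() variant at the end is an out-of-place alternative that I've added, not something from the question. It should give the same result.

```python
import torch

position_ids = torch.arange(500).reshape(10, 50)
prompt_pos = torch.arange(50).reshape(10, 5)

# Reference result from the original loop, computed on a copy.
expected = position_ids.clone()
for index in range(10):
    expected[index][prompt_pos[index][0]:] = prompt_pos[index][0]

# Per-row start offset, shape (10, 1), so it broadcasts across columns.
start = prompt_pos[:, :1]
# Column indices 0..49, shape (1, 50).
cols = torch.arange(position_ids.size(1), device=position_ids.device).unsqueeze(0)
# mask[i, j] is True where column j is at or past row i's start offset.
mask = cols >= start  # shape (10, 50)

# Boolean-mask assignment: broadcast each row's start value to the full
# shape, then select it with the same mask so the element counts match.
result = position_ids.clone()
result[mask] = torch.masked_select(start.expand_as(result), mask)

# Out-of-place alternative: take `start` where the mask is True, else keep the original.
result_where = torch.where(mask, start, position_ids)
```

To update in place, as the loop does, assign into position_ids itself instead of the clone. The mask is built from one comparison on broadcast tensors, so no Python-level iteration over rows remains.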