How to fill selected values with index into another embedding?

Hi,

How could I achieve the following without a for loop?
Should I use the scatter function?
Thanks!

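```python
import torch

position_ids = torch.arange(500).reshape(10, 50)
prompt_pos = torch.arange(50).reshape(10, 5)
# Row i: overwrite columns prompt_pos[i][0] onward with the value prompt_pos[i][0]
for index in range(10):
    position_ids[index][prompt_pos[index][0]:] = prompt_pos[index][0]
```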
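To make the goal concrete, here is the same loop run on a small hypothetical input (2 rows of 6 positions, start columns 2 and 4; these values are made up for illustration). Only `prompt_pos[:, 0]` is read: each row keeps its values before that column and is filled with the column index itself from there on.

```python
import torch

# Hypothetical small example: 2 rows of 6 positions each.
position_ids = torch.arange(12).reshape(2, 6)
# Only column 0 is used by the loop; it is the start column (and fill value) per row.
prompt_pos = torch.tensor([[2, 0], [4, 0]])

for index in range(2):
    start = prompt_pos[index][0]
    # Slicing a row gives a view, so this assignment writes into position_ids.
    position_ids[index][start:] = start

# Rows become [0, 1, 2, 2, 2, 2] and [6, 7, 8, 9, 4, 4].
print(position_ids)
```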

@ptrblck, could I get your help? Thanks!

Hi Albert!

I don’t see a way to do this with scatter(). You can, however, compute
an appropriate boolean mask and then use it to index into position_ids
(or, equivalently, use torch.masked_select() to gather the values):

```
>>> import torch
>>> torch.__version__
'1.10.2'
>>>
>>> position_ids = torch.arange(500).reshape(10, 50)
>>> prompt_pos = torch.arange(50).reshape(10, 5)
>>>
>>> position_idsB = position_ids.clone()
>>>
>>> for index in range(10):
...     position_ids[index][prompt_pos[index][0]:] = prompt_pos[index][0]
...
>>> msk = torch.arange(50).unsqueeze(0).expand(10, 50) >= prompt_pos[:, 0].unsqueeze(1)
>>>
>>> position_idsB[msk] = prompt_pos[:, 0].unsqueeze(1).expand(10, 50)[msk]
>>>
>>> torch.equal(position_ids, position_idsB)
True
```
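The same masking idea can be written without hard-coding the shape (10, 50) by letting broadcasting build the mask. This is a sketch, not code from the thread: `fill_from_start` is a hypothetical helper name, and it swaps in `torch.where` in place of the boolean-index assignment so that it returns a new tensor. It also checks that `torch.masked_select()` gathers the same values as boolean indexing.

```python
import torch


def fill_from_start(position_ids: torch.Tensor, start: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each row i, replace position_ids[i, j] with
    start[i] for every column j >= start[i]; other entries are unchanged."""
    n_cols = position_ids.size(1)
    cols = torch.arange(n_cols, device=position_ids.device)  # shape (n_cols,)
    # Broadcasting (1, n_cols) against (n_rows, 1) gives an (n_rows, n_cols) mask.
    msk = cols.unsqueeze(0) >= start.unsqueeze(1)
    # Where the mask is True take start[i], otherwise keep the original value.
    return torch.where(msk, start.unsqueeze(1).to(position_ids.dtype), position_ids)


position_ids = torch.arange(500).reshape(10, 50)
prompt_pos = torch.arange(50).reshape(10, 5)
start = prompt_pos[:, 0]

# Reference result from the original for loop, computed on a copy.
expected = position_ids.clone()
for index in range(10):
    expected[index][start[index]:] = start[index]

result = fill_from_start(position_ids, start)
print(torch.equal(result, expected))

# masked_select returns the selected values as a 1-D tensor, like values[msk].
msk = torch.arange(50).unsqueeze(0) >= start.unsqueeze(1)
values = start.unsqueeze(1).expand(10, 50)
print(torch.equal(torch.masked_select(values, msk), values[msk]))
```

Because `torch.where` builds a new tensor, `position_ids` itself is left untouched, whereas the boolean-index assignment above modifies `position_idsB` in place. If the tensors live on a GPU, `start` must be on the same device as `position_ids`.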

Best.

K. Frank


Thanks! It works ~.~