Given a 2D tensor of zeros to fill:
import torch
res = torch.zeros((4, 3))
a 2D tensor holding the values to fill in, with the same size along dim 0 as res (e.g. shape (4, 5)):
fill = torch.tensor([
    [0.1, 0.2, 0.3, 0.4, 0.5],
    [0.6, 0.7, 0.8, 0.9, 1.0],
    [1.1, 1.2, 1.3, 1.4, 1.5],
    [1.6, 1.7, 1.8, 1.9, 2.0],
])
and an index tensor with the same shape as fill (4, 5), whose values lie in the index range of dim 1 of res (0 to 2 in this case). The max value (2) can repeat several times per row, while the other values appear 0 or 1 times per row:
index = torch.tensor([
    [0, 2, 2, 2, 2],
    [0, 1, 2, 2, 2],
    [1, 2, 2, 2, 0],
    [2, 2, 1, 2, 2],
], dtype=torch.int64)
I would like to fill res row by row, using fill and index. I don’t care which of the repeated values ends up at index 2 of res. For example, this would be an acceptable result:
final_res = [
    [0.1, 0.0, 0.3],
    [0.6, 0.7, 0.9],
    [1.5, 1.1, 1.4],
    [0.0, 1.8, 1.6]
]
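
One possible way to do this, as a minimal sketch rather than a definitive answer, is `torch.Tensor.scatter_` along dim 1, which for each position performs `res[i, index[i, j]] = fill[i, j]`. This assumes it really is acceptable for column 2 to receive an arbitrary one of its candidate values: when indices repeat within a row, which source value gets written is not deterministic.

```python
import torch

res = torch.zeros((4, 3))
fill = torch.tensor([
    [0.1, 0.2, 0.3, 0.4, 0.5],
    [0.6, 0.7, 0.8, 0.9, 1.0],
    [1.1, 1.2, 1.3, 1.4, 1.5],
    [1.6, 1.7, 1.8, 1.9, 2.0],
])
index = torch.tensor([
    [0, 2, 2, 2, 2],
    [0, 1, 2, 2, 2],
    [1, 2, 2, 2, 0],
    [2, 2, 1, 2, 2],
], dtype=torch.int64)

# In-place scatter along dim 1: res[i, index[i, j]] = fill[i, j].
# Positions of res that no index points to keep their initial 0.0.
res.scatter_(1, index, fill)

# Columns 0 and 1 are fully determined by the inputs;
# column 2 holds one of the values whose index is 2 in that row.
print(res)
```

Columns 0 and 1 should match the first two columns of final_res above, while column 2 may differ from run to run or between devices.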