Index fill on 2d tensor, index and fill value

scatter_ should work here. With dim=1 it writes fill[i][j] into res[i][index[i][j]] for every (i, j):

import torch

res = torch.zeros((4, 3))

fill = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                     [0.6, 0.7, 0.8, 0.9, 1.0],
                     [1.1, 1.2, 1.3, 1.4, 1.5],
                     [1.6, 1.7, 1.8, 1.9, 2.0],
                     ])

index = torch.tensor([[0, 2, 2, 2, 2],
                      [0, 1, 2, 2, 2],
                      [1, 2, 2, 2, 0],
                      [2, 2, 1, 2, 2],
                      ], dtype=torch.int64)

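# Note: res[index] = fill does not do this. It selects whole rows of res
# (shape [4, 5, 3]) and fails with a shape-mismatch error.
# For every (i, j): res[i][index[i][j]] = fill[i][j]
res.scatter_(1, index, fill)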
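print(res)
# tensor([[0.1000, 0.0000, 0.5000],
#         [0.6000, 0.7000, 1.0000],
#         [1.5000, 1.1000, 1.4000],
#         [0.0000, 1.8000, 2.0000]])
# Repeated indices within a row (e.g. the 2s) target the same cell. Here the
# last write wins, but the PyTorch docs say the chosen value is
# non-deterministic when indices are not unique.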
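To make the semantics concrete, here is a plain-loop sketch of what scatter_ computes along dim=1 for 2D tensors. It assumes that later columns overwrite earlier ones when an index repeats, which is the behavior visible in the output above. scatter_rows_reference is a hypothetical helper name, not a PyTorch function.

```python
import torch

def scatter_rows_reference(res, index, fill):
    # Hypothetical helper: loop version of res.scatter(1, index, fill) for
    # 2D tensors, i.e. out[i][index[i][j]] = fill[i][j] for every (i, j).
    # Columns are visited left to right, so for a repeated index the
    # rightmost value ends up in the cell ("last write wins").
    out = res.clone()  # leave the input untouched
    for i, (cols, vals) in enumerate(zip(index.tolist(), fill.tolist())):
        for col, val in zip(cols, vals):
            out[i, col] = val
    return out

res = torch.zeros((4, 3))
fill = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                     [0.6, 0.7, 0.8, 0.9, 1.0],
                     [1.1, 1.2, 1.3, 1.4, 1.5],
                     [1.6, 1.7, 1.8, 1.9, 2.0]])
index = torch.tensor([[0, 2, 2, 2, 2],
                      [0, 1, 2, 2, 2],
                      [1, 2, 2, 2, 0],
                      [2, 2, 1, 2, 2]])  # int64 by default

ref = scatter_rows_reference(res, index, fill)
print(ref)
```

Under the last-write-wins assumption this reproduces the tensor printed in the answer; it is meant for reading and checking, not speed.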
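The PyTorch docs for scatter_ note that when indices are not unique, which value gets written is non-deterministic (in practice this matters most on CUDA). If you specifically need "last column wins", one option is the sketch below. It is not from this thread; it is an alternative technique assuming a PyTorch version that provides Tensor.scatter_reduce_ with reduce="amax". It scatters each source column position j with amax, so every output cell learns the largest j that targets it, then gathers fill at that position. Because amax does not depend on write order, the result is deterministic. scatter_last_wins is a hypothetical helper name.

```python
import torch

def scatter_last_wins(res, index, fill):
    # Hypothetical helper: same result as res.scatter(1, index, fill) with
    # "last write wins" for duplicate indices, without relying on the
    # write order of scatter_.
    n_rows, n_cols = index.shape
    # Column position j of every source element, one row per output row.
    positions = torch.arange(n_cols, device=index.device).repeat(n_rows, 1)
    # For each output cell, the largest j that targets it (-1 if none does).
    # amax is order-independent, so this step is deterministic.
    last_pos = torch.full(res.shape, -1, dtype=torch.long, device=index.device)
    last_pos.scatter_reduce_(1, index, positions, reduce="amax")
    # Pick fill[i][last_pos[i][k]]; clamp keeps gather in range for -1 cells.
    picked = fill.gather(1, last_pos.clamp(min=0))
    # Cells that no index points at keep their original value from res.
    return torch.where(last_pos >= 0, picked, res)

res = torch.zeros((4, 3))
fill = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                     [0.6, 0.7, 0.8, 0.9, 1.0],
                     [1.1, 1.2, 1.3, 1.4, 1.5],
                     [1.6, 1.7, 1.8, 1.9, 2.0]])
index = torch.tensor([[0, 2, 2, 2, 2],
                      [0, 1, 2, 2, 2],
                      [1, 2, 2, 2, 0],
                      [2, 2, 1, 2, 2]])

out = scatter_last_wins(res, index, fill)
print(out)
```

The trade-off is an extra scatter and a gather, in exchange for a result that does not depend on how scatter_ orders its writes.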