What does the scatter_ function do in layman's terms?

In case anyone lands here after reading the documentation, here is an explanation of how the first result in the docs example is produced:

>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

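The scatter_ call says “send the elements of x to the following positions in torch.zeros(3, 5), choosing the destination along dim 0 (the row)”. In layman’s terms: for each element of x, the index tensor gives a row index (0, 1 or 2) to send it to in the tensor we are scattering into, while the column stays the same. For dim 0 the rule is self[index[i][j]][j] = x[i][j].

For example, x[0][1] = 0.2908 has index 1, so it lands in row 1, column 1. x[1][1] = 0.9006 has index 0, so it lands in row 0, column 1. Positions that no element is sent to keep their original value of 0.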
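To make the rule concrete, here is a minimal sketch that mimics scatter_ along dim 0 with plain Python lists instead of tensors. The helper name scatter_dim0 is made up for illustration and is not part of PyTorch; it only handles the 2-D, dim 0 case shown above.

```python
def scatter_dim0(dest, index, src):
    """Illustrative stand-in for dest.scatter_(0, index, src) on 2-D nested lists.

    Hypothetical helper, not a PyTorch function.
    """
    for i, index_row in enumerate(index):
        for j, target_row in enumerate(index_row):
            # The index value picks the destination ROW; the column j is unchanged.
            dest[target_row][j] = src[i][j]
    return dest


# Same values as the example above.
x = [[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
     [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]
index = [[0, 1, 2, 0, 0],
         [2, 0, 0, 1, 2]]
zeros = [[0.0] * 5 for _ in range(3)]  # plays the role of torch.zeros(3, 5)

result = scatter_dim0(zeros, index, x)
for row in result:
    print(row)
```

The printed rows should match the tensor returned by scatter_ above, apart from number formatting. Note that this example never sends two elements to the same position; the sketch simply lets the last write win in that case, which is an assumption of the sketch rather than a guarantee about PyTorch.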

https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_
