What does the scatter_ function do in layman's terms?

I read the examples and the explanation in the documentation, but I still can’t understand what this function does.


Hi,

It sets particular values of a tensor at the provided indices. The value that you set can either be the same scalar everywhere or come from a source tensor, which must be at least as large as the index tensor.
For example, if you want to one-hot encode a 1D tensor of labels, you can start with a 2D tensor filled with zeros and then scatter 1s into it according to the label of each entry.
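
Here is a minimal sketch of that one-hot idea. The label values and `num_classes` are made up for illustration, and the labels are reshaped to a column because the index tensor needs as many dimensions as the tensor being written into.

```python
import torch

# Hypothetical labels for four samples, with three possible classes.
labels = torch.tensor([2, 0, 1, 2])
num_classes = 3

# Start from zeros: one row per sample, one column per class.
one_hot = torch.zeros(labels.size(0), num_classes)

# Along dim=1 each label picks the column of its own row that
# receives the scalar value 1.0. unsqueeze(1) turns shape (4,) into (4, 1).
one_hot.scatter_(1, labels.unsqueeze(1), 1.0)

print(one_hot)
```

Row 0 gets its 1 in column 2, row 1 in column 0, and so on, one 1 per row.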


In case anyone is looking at this after the documentation, here is an explanation for how they arrived at the first result:

>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

The scatter says “send the elements of x to the following indices in torch.zeros, row-wise (dim 0)”. In layman’s terms: for each element of the original tensor
[[0.3992, 0.2908, 0.9044, 0.4850, 0.6004], [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]], we specify a row index (0, 1 or 2) to send it to in the tensor we are scattering into. Each element keeps its column.
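
To make the rule concrete, here is a rough sketch that spells out the dim=0 scatter as two plain loops and compares it with the built-in call. It hard-codes the values of x from the example above instead of calling torch.rand; the names `manual` and `builtin` are just for illustration.

```python
import torch

# The source and index tensors from the example above.
src = torch.tensor([[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
                    [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
index = torch.tensor([[0, 1, 2, 0, 0],
                      [2, 0, 0, 1, 2]])

# For dim=0 the rule is: out[index[i][j]][j] = src[i][j].
# The index supplies the destination row; the column j is unchanged.
manual = torch.zeros(3, 5)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        row = index[i, j].item()
        manual[row, j] = src[i, j]

builtin = torch.zeros(3, 5).scatter_(0, index, src)

print(torch.equal(manual, builtin))
```

No two elements in the same column share a destination row here, so the result does not depend on the order of the writes.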

https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_


Sorry for asking but could you elaborate more on this?
I don’t understand the index part!?

Update:
I think I understand it now, thanks to God!
0, 1, 2, 0, 0 are row indices into our destination tensor.
Each element of the source is sent to the row named by its index, and it keeps its own column.
So the 0th element of x, 0.3992, goes to row 0, column 0.
The 1st element of x, 0.2908, goes to row 1, column 1.
The 2nd element of x, 0.9044, goes to row 2, column 2.
The 3rd element of x, 0.4850, goes to row 0, column 3.
The 4th element of x, 0.6004, goes to row 0, column 4.
And so on for the second row of x.
Basically, every element of the source keeps its column and is only redirected to a specific row in the destination tensor!
Is this right?


@Shisho_Sama: almost. You got it right for dim=0 (i.e. rows in this case). But with dim=1 (i.e. columns), each source element keeps its row, and the indices specify which column of the destination it goes to.

Essentially, ‘index’ replaces the position along ‘dim’, and every other position in the output stays the same as in ‘src’.
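
A small sketch of the dim=1 case, with made-up values so the result is easy to follow by hand. Here the rule becomes out[i][index[i][j]] = src[i][j].

```python
import torch

# Hypothetical source and index tensors, chosen for readability.
src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])
index = torch.tensor([[2, 0, 1],
                      [0, 2, 1]])

# With dim=1 the index supplies the destination column;
# the row i of each element is unchanged.
out = torch.zeros(2, 4).scatter_(1, index, src)

print(out)
# tensor([[2., 3., 1., 0.],
#         [4., 6., 5., 0.]])
```

In row 0, the value 1 goes to column 2, 2 goes to column 0 and 3 goes to column 1; column 3 is never indexed, so it stays 0.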


The ‘_’ in scatter_ signifies that the operation happens in place: it modifies the tensor it is called on instead of returning a new one.
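
A quick sketch of the difference, using the out-of-place variant `scatter` (no underscore) for comparison. The tensor names and values are only for illustration.

```python
import torch

base = torch.zeros(2, 3)
index = torch.tensor([[1], [2]])

# Out-of-place: returns a new tensor and leaves `base` untouched.
copy = base.scatter(1, index, 1.0)
base_still_zero = bool((base == 0).all())

# In-place: writes into `base` itself and returns that same tensor.
returned = base.scatter_(1, index, 1.0)

print(base_still_zero, returned is base, torch.equal(base, copy))
```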
