What does the scatter_ function do in layman's terms?

(Kwahi) #1

I read the examples and the explanation in the documentation, but I still can't understand what this function does.

(Alban D) #2

Hi,

It sets particular values of a tensor at the provided indices. The value that you set can either always be the same or be provided by another tensor, which is the same size as the indices tensor.
For example, if you want to do a one-hot encoding of a 1D tensor of labels, you can start with a 2D tensor filled with zeros and then scatter 1s according to the label of each entry.
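
A minimal sketch of that one-hot idea, assuming four made-up labels and three classes (the names `labels`, `num_classes` and `one_hot` are only illustrative):

```python
import torch

# Hypothetical labels for 4 samples, each one of 3 classes.
labels = torch.tensor([2, 0, 1, 2])
num_classes = 3

# Start from a (num_samples, num_classes) tensor of zeros.
one_hot = torch.zeros(labels.size(0), num_classes)

# The index must have the same number of dimensions as the target,
# so reshape the labels from shape (4,) to shape (4, 1).
# With dim=1, each row i gets a 1 written into column labels[i].
one_hot.scatter_(1, labels.unsqueeze(1), 1.0)

print(one_hot)
```

Here the same value (1.0) is written everywhere, which is the "always the same" case; passing a tensor instead of 1.0 would write a different value at each index.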

(John) #3

In case anyone is looking at this after the documentation, here is an explanation for how they arrived at the first result:

>>> import torch
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

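The scatter says "send each element of x to a row of the torch.zeros(3, 5) tensor, with the row chosen by the index tensor" (dim 0 means the indices are row indices). In layman's terms, for each element of the original x tensor
[[0.3992, 0.2908, 0.9044, 0.4850, 0.6004], [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]
we specify a row index (0, 1 or 2) to send it to in the tensor we are scattering into, while the element stays in the same column. For example, 0.2908 sits in column 1 of x and its index is 1, so it ends up in row 1, column 1 of the result.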
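
To make the rule concrete, here is a rough sketch that reproduces the example above with plain loops and compares it to scatter_. The x values are hard-coded from the printed output, and the names `src`, `index`, `manual` and `scattered` are just for illustration; this mimics the behaviour for dim 0 and is not how PyTorch implements it internally.

```python
import torch

# Values copied from the printed x in the example above.
src = torch.tensor([[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
                    [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
index = torch.tensor([[0, 1, 2, 0, 0],
                      [2, 0, 0, 1, 2]])

# Manual version of scatter_ along dim 0:
# element (i, j) of src goes to row index[i, j], same column j.
manual = torch.zeros(3, 5)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        row = index[i, j].item()
        manual[row, j] = src[i, j]

# The built-in call from the example.
scattered = torch.zeros(3, 5).scatter_(0, index, src)

print(manual)
print(torch.equal(manual, scattered))
```

Positions that no index points to (for example row 1, column 0) simply keep their original zero.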

https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_