Hi,
I’m a bit puzzled about how scatter_add
(https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_) works.
Problem
Why can’t I reproduce the same result with for-loops that follow the update rule from the docs (shown here for a 3-D tensor)?
self[index[i][j][k]][j][k] += other[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] += other[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] += other[i][j][k] # if dim == 2
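For reference, the rule above can be written out in plain Python for the 3-D case. This is only a sketch on nested lists, not PyTorch's implementation: the helper name `scatter_add_3d` and the sample values are made up for illustration, and the helper copies its input first to mimic the out-of-place `scatter_add` (as opposed to the in-place `scatter_add_`).

```python
import copy

def scatter_add_3d(dst, dim, index, src):
    """Hypothetical reference: apply the documented rule to nested lists."""
    out = copy.deepcopy(dst)  # out-of-place: the input is left untouched
    for i in range(len(index)):
        for j in range(len(index[i])):
            for k in range(len(index[i][j])):
                pos = [i, j, k]
                pos[dim] = index[i][j][k]  # swap in the index along `dim`
                a, b, c = pos
                out[a][b][c] += src[i][j][k]
    return out

# Illustrative values only (not taken from the examples in this post)
dst = [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]  # shape 2x2x2
index = [[[1, 0], [0, 1]]]                   # shape 1x2x2
src = [[[10, 20], [30, 40]]]                 # shape 1x2x2

result = scatter_add_3d(dst, 0, index, src)
print(result)  # [[[0, 20], [30, 0]], [[10, 0], [0, 40]]]
```

Note that the loops run over the shape of `index`, not the shape of `dst`, and that `dst` itself is unchanged afterwards.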
Examples
I’m testing the method progressively, starting from 1-D.
1-D tensor.
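import numpy as np
import tensorflow as tf
import torch

a = tf.Variable([1,2,3], dtype=tf.float32)  # not used below
b = torch.tensor([1,2,3], dtype=torch.float32)
# np.int has been removed from recent NumPy releases; scatter_add needs an int64 index
indices = np.array([0,0,1,2,1,0,2], dtype=np.int64)
updates = np.arange(len(indices), dtype=np.float32)
b.scatter_add(0, torch.from_numpy(indices), torch.from_numpy(updates))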
# tensor([ 7., 8., 12.])
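The 1-D case does match a hand-written loop. Here is a minimal plain-Python check of the dim == 0 rule using the same values, with lists standing in for the tensors (so it is a sketch of the rule, not of PyTorch itself):

```python
b = [1.0, 2.0, 3.0]
indices = [0, 0, 1, 2, 1, 0, 2]
updates = [float(n) for n in range(len(indices))]  # 0.0 .. 6.0

expected = list(b)  # copy so the starting values are not modified
for i, idx in enumerate(indices):
    expected[idx] += updates[i]  # self[index[i]] += other[i]

print(expected)  # [7.0, 8.0, 12.0]
```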
2-D tensor.
array = np.array([
    [1, 2, 3],
    [4, 5, 6]
], dtype=np.float32)
b = torch.tensor(array, dtype=torch.float32)
# np.int has been removed from recent NumPy releases; scatter_add needs an int64 index
indices = np.array([
    [0, 0], [0, 1]
], dtype=np.int64)
updates = np.array(
    [[1, 2], [3, 4]],
    dtype=np.float32)
Here I run into an error:
b.scatter_add(0, torch.from_numpy(indices), torch.from_numpy(updates))
# Isn't the above the same as below?
for i in range(2):
    for j in range(2):
        b[indices[i,j]][j] += updates[i,j]
# tensor([[5., 4., 3.],
# [4., 9., 6.]])
The following does not raise an error:
b.scatter_add(1, torch.from_numpy(indices), torch.from_numpy(updates))
# tensor([[11., 4., 3.],
# [10., 17., 6.]])
# However, the result is not the same as
for i in range(2):
    for j in range(2):
        b[i, indices[i,j]] += updates[i,j]
# tensor([[ 8., 4., 3.],
# [ 7., 13., 6.]])
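A plain-Python sketch that may help narrow this down. It assumes all of the 2-D snippets were run in one session in the order shown (the post does not confirm this), so that the for-loops modify b in place while `scatter_add` without the trailing underscore returns a new tensor. The helper `scatter_add_2d` is hypothetical and works on nested lists, not tensors.

```python
import copy

def scatter_add_2d(dst, dim, index, src):
    """Hypothetical out-of-place reference for the 2-D rule."""
    out = copy.deepcopy(dst)
    for i in range(len(index)):
        for j in range(len(index[i])):
            if dim == 0:
                out[index[i][j]][j] += src[i][j]
            else:
                out[i][index[i][j]] += src[i][j]
    return out

original = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
indices = [[0, 0], [0, 1]]
updates = [[1.0, 2.0], [3.0, 4.0]]

# Applied to the untouched starting values
fresh_dim0 = scatter_add_2d(original, 0, indices, updates)
fresh_dim1 = scatter_add_2d(original, 1, indices, updates)

# Mimic running both in-place loops back to back on the same object
b = copy.deepcopy(original)
for i in range(2):
    for j in range(2):
        b[indices[i][j]][j] += updates[i][j]  # dim == 0 loop
after_dim0_loop = copy.deepcopy(b)
for i in range(2):
    for j in range(2):
        b[i][indices[i][j]] += updates[i][j]  # dim == 1 loop
after_dim1_loop = copy.deepcopy(b)

# Out-of-place dim == 1 update on the already-modified values
on_mutated = scatter_add_2d(b, 1, indices, updates)

print(fresh_dim0)       # [[5.0, 4.0, 3.0], [4.0, 9.0, 6.0]]
print(fresh_dim1)       # [[4.0, 2.0, 3.0], [7.0, 9.0, 6.0]]
print(after_dim1_loop)  # [[8.0, 4.0, 3.0], [7.0, 13.0, 6.0]]
print(on_mutated)       # [[11.0, 4.0, 3.0], [10.0, 17.0, 6.0]]
```

Under that assumption, the numbers above line up with the outputs quoted in this post: the loop results and the `scatter_add` result differ only in which state of b they started from. Re-creating b (or cloning it) before each experiment would be one way to test whether that is what happened.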
Thanks!