Is PyTorch's scatter_add_ exactly the same as the TensorFlow version?

I want to implement CenterLoss in PyTorch. The TensorFlow code is something like:

centers_update_op = tf.scatter_sub(centers, labels, diff)

which, apart from the sign (scatter_sub subtracts, while this loop adds), is equivalent to:

# labels:  [batch_size, 1]
# centers: [class_nums, emb_size]
# diff:    [batch_size, emb_size]
for idx, label in enumerate(labels):
    centers[label, :] += diff[idx]
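For reference, here is a minimal runnable version of that loop. The sizes and tensor values are made up purely for illustration. Note that rows sharing a label are summed; the TensorFlow docs state that tf.scatter_sub also accumulates the contributions of duplicate indices.

```python
import torch

# Hypothetical sizes and values, chosen only for illustration.
batch_size, class_nums, emb_size = 4, 3, 2

centers = torch.zeros(class_nums, emb_size)
labels = torch.tensor([[0], [2], [0], [1]])  # [batch_size, 1]
diff = torch.arange(batch_size * emb_size, dtype=torch.float32).view(batch_size, emb_size)

# Row-wise accumulation: each row of diff is added to the row of centers
# selected by its label; rows that share a label (0 here) are summed.
for idx, label in enumerate(labels):
    centers[label, :] += diff[idx]

print(centers)
```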

I tried to use the scatter_add_ API in PyTorch:

centers.scatter_add_(0, labels, diff)

However, it only updates the first element along the emb_size dimension, so I had to fall back to the for loop, which I expect to be rather slow. Is there an API that can do this in a single call? Thanks!
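A small sketch of why this happens, using made-up tensors: scatter_add_ iterates over the elements of the index tensor, not of src. With dim=0 it performs self[index[i][j]][j] += src[i][j] for every (i, j) in index, so a [batch_size, 1] index only ever reaches column 0. Expanding the index to the shape of diff (one common workaround, not something from the original post) makes every column follow its row's label.

```python
import torch

# Hypothetical sizes and values, chosen only for illustration.
class_nums, emb_size = 3, 2
labels = torch.tensor([[0], [2], [0], [1]])  # [batch_size, 1]
diff = torch.arange(8, dtype=torch.float32).view(4, emb_size)

# Index of shape [batch_size, 1]: only j = 0 is visited, so only
# column 0 of each selected row is updated.
partial = torch.zeros(class_nums, emb_size)
partial.scatter_add_(0, labels, diff)

# Index expanded to [batch_size, emb_size]: every column of row i is
# sent to row labels[i], giving the full row-wise update.
full = torch.zeros(class_nums, emb_size)
full.scatter_add_(0, labels.expand(-1, emb_size), diff)

print(partial)
print(full)
```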

I figured it out using index_add_; the code looks like this:

centers.index_add_(0, labels.flatten(), diff)
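A minimal self-contained check of the index_add_ approach, with made-up tensors: index_add_ takes a 1-D index with one entry per row of the source and adds whole rows along dim 0, accumulating duplicates. Since tf.scatter_sub subtracts, a sketch of the matching call simply passes the negated diff.

```python
import torch

# Hypothetical sizes and values, chosen only for illustration.
class_nums, emb_size = 3, 2
labels = torch.tensor([[0], [2], [0], [1]])  # [batch_size, 1]
diff = torch.arange(8, dtype=torch.float32).view(4, emb_size)

# Add each row of diff into the row of centers given by its label.
added = torch.zeros(class_nums, emb_size)
added.index_add_(0, labels.flatten(), diff)

# Same update with scatter_sub semantics: add the negated diff.
subbed = torch.zeros(class_nums, emb_size)
subbed.index_add_(0, labels.flatten(), -diff)

print(added)
print(subbed)
```

If centers is an nn.Parameter that requires grad, in-place updates like this generally need to be wrapped in torch.no_grad(), otherwise PyTorch raises an error for modifying a leaf tensor in place.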