I want to implement CenterLoss in PyTorch. The TensorFlow code is something like this:
centers_update_op = tf.scatter_sub(centers, labels, diff)
which, apart from the sign (scatter_sub subtracts each diff row instead of adding it), is equivalent to:
labels -> [batch_size, 1]
centers->[class_nums, emb_size]
diff->[batch_size, emb_size]
for idx, label in enumerate(labels):
    centers[label, :] += diff[idx]
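For reference, the loop above can be written as a runnable PyTorch baseline. This is a minimal sketch assuming labels is an integer tensor of shape [batch_size, 1] as listed; the helper name loop_update and the toy sizes are hypothetical. Note that a label appearing twice in the batch receives both updates, which any vectorized replacement has to preserve.

```python
import torch

def loop_update(centers, labels, diff):
    # Hypothetical helper mirroring the Python loop above.
    # centers: [class_nums, emb_size], labels: [batch_size, 1], diff: [batch_size, emb_size]
    out = centers.clone()
    for idx, label in enumerate(labels.view(-1).tolist()):
        out[label, :] += diff[idx]  # whole embedding row is updated
    return out

centers = torch.zeros(3, 4)
labels = torch.tensor([[0], [2], [0]])  # label 0 appears twice
diff = torch.ones(3, 4)
updated = loop_update(centers, labels, diff)
print(updated)  # row 0 gets two updates, row 1 none, row 2 one
```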
I tried to use the scatter_add_ API in PyTorch:
centers.scatter_add_(0, labels, diff)
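A likely explanation for the behaviour described next: for dim=0, scatter_add_ performs self[index[i][j]][j] += src[i][j] only for the (i, j) positions that exist in index. With labels of shape [batch_size, 1], j only takes the value 0, so only the first column is touched. The sketch below (toy sizes assumed) compares that with an index expanded to diff's shape.

```python
import torch

class_nums, emb_size, batch_size = 3, 4, 3
centers = torch.zeros(class_nums, emb_size)
labels = torch.tensor([[0], [2], [0]])   # [batch_size, 1]
diff = torch.ones(batch_size, emb_size)  # [batch_size, emb_size]

# Index of shape [batch_size, 1]: only column 0 is visited,
# i.e. centers[labels[i][0]][0] += diff[i][0].
first_col_only = centers.clone().scatter_add_(0, labels, diff)

# Expand the index to [batch_size, emb_size] (a view, no copy) so every
# column j is visited: centers[labels[i][0]][j] += diff[i][j].
index = labels.expand(-1, emb_size)
full_rows = centers.clone().scatter_add_(0, index, diff)

print(first_col_only)
print(full_rows)
```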
However, it only updates the first element of each embedding row (column 0 of the emb_size columns). For now I have to use a Python for loop instead, which I expect to be rather slow. Is there an API that does this row-wise update in a single call? Thanks!
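One possible single-call alternative is Tensor.index_add_, which takes a 1-D index and adds whole slices along the given dimension, accumulating duplicate labels. This is a sketch, not a confirmed CenterLoss implementation: the helper name update_centers is hypothetical, diff is negated to mimic tf.scatter_sub (drop the minus sign to match the += loop), and centers is assumed to be updated manually rather than through autograd, hence torch.no_grad().

```python
import torch

def update_centers(centers, labels, diff):
    # Hypothetical helper: vectorized replacement for the per-sample loop.
    # labels [batch_size, 1] is flattened because index_add_ needs a 1-D index.
    # Negating diff gives scatter_sub semantics: centers[label] -= diff[idx].
    with torch.no_grad():  # assumption: centers is not trained by autograd
        centers.index_add_(0, labels.view(-1), -diff)
    return centers

centers = torch.zeros(3, 4)
labels = torch.tensor([[0], [2], [0]])
diff = torch.ones(3, 4)
update_centers(centers, labels, diff)
print(centers)  # row 0 gets -2, row 1 stays 0, row 2 gets -1
```

Compared with expanding the index for scatter_add_, index_add_ works directly on row indices, so the labels only need to be flattened.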