Hi,
I have written down the center loss in TensorFlow, and I want to implement it in PyTorch. Since some functions differ between TensorFlow and PyTorch, could anyone help me work out how to implement this code in PyTorch? Thanks in advance.
import tensorflow as tf
def get_center_loss(features, labels, alpha, num_classes, margin=100.0):
    """Build the center-loss terms and the op that updates the class centers.

    Two loss tensors are produced:

    * ``loss_0`` -- an inter-class hinge term. Per-class means of the current
      batch are computed, the squared Euclidean distance between every pair
      of means is taken, and ``max(0, margin - distance)`` is summed over all
      off-diagonal (i != j) entries. Because the mask is the full
      off-diagonal matrix, each unordered pair contributes twice.
    * ``loss_1`` -- an intra-class term: ``tf.nn.l2_loss`` of the difference
      between each feature row and the stored center of its label.

    Args:
        features: Tensor whose second dimension is the feature length (it is
            read via ``get_shape()[1]``). Presumably ``(batch, feature_dim)``
            -- TODO confirm against the caller.
        labels: Class ids for the rows of ``features``; flattened to 1-D and
            used as gather/scatter indices, so they must be integer-typed.
        alpha: Scale factor applied to the center update.
        num_classes: Number of classes, i.e. rows of the ``centers`` variable.
        margin: Python number used as the hinge margin on the squared
            distance between batch class means. Defaults to 100.0.

    Returns:
        Tuple ``(loss_0, loss_1, centers_update_op)``; ``centers_update_op``
        is the ``tf.scatter_sub`` op that moves the stored centers and must
        be run by the caller for the centers to change.
    """
    len_features = features.get_shape()[1]
    # Non-trainable state: centers are changed only through the scatter_sub
    # op returned below, never by the optimizer.
    centers = tf.get_variable(
        'centers', [num_classes, len_features], dtype=tf.float32,
        initializer=tf.constant_initializer(0), trainable=False)
    labels = tf.reshape(labels, [-1])

    # Per-class mean of the features in this batch.
    # NOTE(review): classes absent from the batch still get a row here
    # (the TF docs describe empty segments as yielding 0) and therefore still
    # take part in loss_0 -- confirm this is intended.
    batch_class_means = tf.unsorted_segment_mean(features, labels, num_classes)
    # Ones everywhere except the diagonal: masks out self-distances.
    off_diagonal_mask = tf.ones((num_classes, num_classes)) - tf.eye(num_classes)
    margin_t = tf.constant(margin, dtype=tf.float32)
    # Broadcasting [C, D, 1] against [D, C] gives [C, D, C]; summing the
    # squares over axis 1 leaves the [C, C] squared-distance matrix.
    pairwise_diff = (tf.expand_dims(batch_class_means, 2)
                     - tf.transpose(batch_class_means))
    center_pairwise_dist = tf.transpose(
        tf.reduce_sum(tf.square(pairwise_diff), 1))
    hinge = tf.maximum(0.0, margin_t - center_pairwise_dist)
    loss_0 = tf.reduce_sum(tf.multiply(hinge, off_diagonal_mask))

    # Stored center for each sample's label, and its offset from the sample.
    centers_batch = tf.gather(centers, labels)
    diff = centers_batch - features

    # Divide each sample's offset by (1 + number of times its label occurs in
    # the batch) so frequent classes do not receive oversized updates.
    _, unique_idx, unique_count = tf.unique_with_counts(labels)
    appear_times = tf.gather(unique_count, unique_idx)
    appear_times = tf.reshape(appear_times, [-1, 1])
    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff

    loss_1 = tf.nn.l2_loss(features - centers_batch)
    # Subtracting (center - feature) * alpha / (1 + count) moves each
    # referenced center toward the features carrying its label.
    centers_update_op = tf.scatter_sub(centers, labels, diff)
    return loss_0, loss_1, centers_update_op