Center loss in PyTorch

I have written a center loss in TensorFlow and want to implement it in PyTorch. Since some functions differ between TensorFlow and PyTorch, could anyone help me work out how to implement this code in PyTorch? Thanks in advance.

```python
import tensorflow as tf

# `margin` and `EdgeWeights` were used but never defined in the original snippet;
# taking them as arguments here is an assumption.
def get_center_loss(features, labels, alpha, num_classes, margin, EdgeWeights):
    len_features = features.get_shape()[1]

    # Non-trainable class centers, one row per class.
    centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,
        initializer=tf.constant_initializer(0), trainable=False)
    labels = tf.reshape(labels, [-1])

    # Pairwise squared distances between centers, shape [num_classes, num_classes].
    # The original referred to an undefined `centers0`; it is assumed to mean `centers`.
    norm = lambda x: tf.reduce_sum(tf.square(x), 1)
    center_pairwise_dist = tf.transpose(norm(tf.expand_dims(centers, 2) - tf.transpose(centers)))
    loss_0 = tf.reduce_sum(tf.multiply(tf.maximum(0.0, margin - center_pairwise_dist), EdgeWeights))

    # Center of each sample's own class.
    centers_batch = tf.gather(centers, labels)

    # Per-sample center update, scaled by how often the label occurs in the batch.
    diff = centers_batch - features
    unique_label, unique_idx, unique_count = tf.unique_with_counts(labels)
    appear_times = tf.gather(unique_count, unique_idx)
    appear_times = tf.reshape(appear_times, [-1, 1])
    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff

    loss_1 = tf.nn.l2_loss(features - centers_batch)
    centers_update_op = tf.scatter_sub(centers, labels, diff)

    return loss_0, loss_1, centers_update_op
```

Have a look at this

Thank you. Their implementation is a little bit different from what I want.

No offence, but your question is very ambiguous. Are you asking people to port your code to PyTorch? If so, I'd suggest doing it yourself, line by line, and asking when you get stuck :smile:
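
For anyone attempting the line-by-line port, here is a rough sketch of how the TensorFlow ops in the question might map to PyTorch. It is not a tested drop-in replacement. It assumes that `centers0` in the question means `centers`, that `margin` and `EdgeWeights` (called `edge_weights` here) are passed in by the caller, and that `centers` is a plain tensor without gradients that the caller creates and keeps between steps. The function name `center_loss_step` is made up for this example.

```python
import torch


def center_loss_step(features, labels, centers, alpha, margin, edge_weights):
    """Hypothetical port: returns (loss_0, loss_1) and updates `centers` in place."""
    labels = labels.reshape(-1).long()

    # Pairwise squared distances between centers, shape [C, C]
    # (stands in for the expand_dims / transpose / reduce_sum lines).
    center_diff = centers.unsqueeze(1) - centers.unsqueeze(0)  # [C, C, D]
    center_pairwise_dist = center_diff.pow(2).sum(dim=2)
    # tf.maximum(0.0, x) -> torch.clamp(x, min=0.0)
    loss_0 = (torch.clamp(margin - center_pairwise_dist, min=0.0) * edge_weights).sum()

    # tf.gather(centers, labels) -> integer indexing (this makes a copy).
    centers_batch = centers[labels]

    # tf.nn.l2_loss(x) is sum(x ** 2) / 2.
    loss_1 = 0.5 * (features - centers_batch).pow(2).sum()

    # The center update is bookkeeping, so keep it out of autograd.
    with torch.no_grad():
        diff = centers_batch - features
        # tf.unique_with_counts -> torch.unique with inverse indices and counts.
        _, inverse, counts = torch.unique(labels, return_inverse=True, return_counts=True)
        appear_times = counts[inverse].reshape(-1, 1).to(features.dtype)
        diff = alpha * diff / (1.0 + appear_times)
        # tf.scatter_sub accumulates repeated indices; index_add_ with a
        # negated update does the same.
        centers.index_add_(0, labels, -diff)

    return loss_0, loss_1


# Tiny usage example with made-up numbers.
features = torch.tensor([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]], requires_grad=True)
labels = torch.tensor([0, 0, 1])
centers = torch.zeros(2, 2)
edge_weights = torch.ones(2, 2) - torch.eye(2)  # ignore each center's distance to itself

loss_0, loss_1 = center_loss_step(
    features, labels, centers, alpha=0.5, margin=1.0, edge_weights=edge_weights
)
loss_1.backward()  # gradients flow to `features` only
```

Unlike the TensorFlow version, there is no separate update op to run: the centers are modified in place when the function is called, after both losses have been computed from the old centers.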
