I’m trying to extend a binary classification problem to multi-class and implement it in PyTorch. The implementation of the binary problem is available in TensorFlow. The cost function is defined as

```python
cost = norm * tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
    logits=preds_sub, targets=labels_sub, pos_weight=pos_weight))
```
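For reference, the TF documentation describes the per-element loss that `weighted_cross_entropy_with_logits` computes (TF uses a numerically stabilized form internally; this is the naive version). A minimal plain-Python sketch of that formula, just to show where `pos_weight` enters:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def weighted_bce_with_logits(logit, target, pos_weight):
    # Per the TF docs, pos_weight scales only the positive (target = 1) term:
    #   loss = -pos_weight * target * log(sigmoid(logit))
    #          - (1 - target) * log(1 - sigmoid(logit))
    p = sigmoid(logit)
    return -pos_weight * target * math.log(p) - (1 - target) * math.log(1 - p)

# With pos_weight = 1 this reduces to plain binary cross-entropy;
# pos_weight > 1 makes a positive error (false negative) cost more.
print(weighted_bce_with_logits(0.0, 1.0, 1.0))  # -log(0.5) ≈ 0.6931
print(weighted_bce_with_logits(0.0, 1.0, 2.0))  # ≈ 1.3863
```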
To train the network in my PyTorch implementation, I use `CrossEntropyLoss` to compare the network's class scores with the target labels. My question is whether the `pos_weight` argument of TensorFlow's `weighted_cross_entropy_with_logits` does the same thing as the `weight` argument of PyTorch's `CrossEntropyLoss`. The TF documentation states that `pos_weight` "allows one to trade off recall and precision by up- or down-weighting the cost of a positive error relative to a negative error."

My understanding is that `CrossEntropyLoss` does not differentiate between kinds of misclassification (e.g. class C being misclassified as A carries the same weight as C being misclassified as B). For the `weight` argument of `CrossEntropyLoss` to do the same thing as `weighted_cross_entropy_with_logits`, I think I would need some modification to the weights, such as a weight matrix of size (number of classes) × (number of classes), but I'm not sure. I would appreciate any comments that might clarify this.