How to implement class weights for MSE loss?

I can't find anything about it on Google.

Hi Had!

The short answer is that MSE loss is not a loss that would naturally
be used with classes. As such, PyTorch’s MSELoss does not attempt
to implement any feature that might be considered analogous to
“class weights.”

The reason is that MSE loss is naturally a measure of how much two
sets of real numbers differ from one another.

Class labels (even though implemented in PyTorch as integer
categorical class labels) aren’t conceptually really numbers. You
can test whether two class labels are the same, “bird” = “bird”,
but you can’t ask how far apart two class labels are. That is, what
number would you assign to “bird” - “fish”? Is “bird” a better incorrect
label for “fish” than is “reptile”?
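To make the point concrete, here is a small illustration. It assumes a hypothetical integer encoding (bird = 0, fish = 1, reptile = 2) that is purely arbitrary. Once the labels are fed to MSE as numbers, the loss decides that "reptile" is a worse mistake for "bird" than "fish" is, only because of the order in which the labels happened to be numbered.

```python
import torch
import torch.nn.functional as F

# Hypothetical, arbitrary integer encoding of class labels.
BIRD, FISH, REPTILE = 0.0, 1.0, 2.0

target = torch.tensor([BIRD])

# MSE treats the labels as real numbers, so it measures a "distance" between them.
loss_fish = F.mse_loss(torch.tensor([FISH]), target)        # (1 - 0)^2
loss_reptile = F.mse_loss(torch.tensor([REPTILE]), target)  # (2 - 0)^2

print(loss_fish.item(), loss_reptile.item())  # 1.0 4.0
```

Renumbering the classes (say, reptile = 1, fish = 2) would reverse that ranking, which is exactly why the distance is not meaningful for class labels.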

Having said all that, you can write your own MSE loss that takes
sample weights. If you have some scheme for assigning samples
to classes, you could then construct a vector of sample weights
(for each batch of samples) that are given by the class weight of
each sample’s corresponding class.
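A minimal sketch of such a sample-weighted MSE, assuming PyTorch tensors with the batch as the first dimension. The function name weighted_mse_loss, the class weights, and the sample-to-class assignment below are all hypothetical placeholders for whatever scheme you use; this is not a built-in PyTorch API.

```python
import torch

def weighted_mse_loss(pred: torch.Tensor, target: torch.Tensor,
                      sample_weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: MSE where each sample's error is scaled by its weight.

    pred, target: shape (batch, ...)
    sample_weight: shape (batch,)
    """
    sq_err = (pred - target) ** 2
    # Average the squared error over all non-batch dims -> one value per sample.
    per_sample = sq_err.reshape(sq_err.shape[0], -1).mean(dim=1)
    # Weighted mean over the batch, normalized by the total weight.
    return (sample_weight * per_sample).sum() / sample_weight.sum()

# Hypothetical class weights: samples of class 1 count three times as much.
class_weights = torch.tensor([1.0, 3.0])
# Hypothetical assignment of each sample in the batch to a class.
sample_class = torch.tensor([0, 1, 1, 0])
# Look up each sample's weight from its class: tensor([1., 3., 3., 1.])
sample_weight = class_weights[sample_class]

pred = torch.zeros(4, 2)
target = torch.tensor([[1.0, 1.0], [2.0, 2.0], [0.0, 0.0], [1.0, 1.0]])
loss = weighted_mse_loss(pred, target, sample_weight)
print(loss.item())  # (1*1 + 3*4 + 3*0 + 1*1) / 8 = 1.75
```

Dividing by the sum of the weights (rather than the batch size) is one possible design choice: it keeps the loss on the same scale as ordinary MSE, and with all weights equal to 1 it reduces to the usual mean over all elements.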

Good luck.

K. Frank

Well, in my case MSE seems to work better.
This is not a simple classification task; I use it in the middle of a seq2seq model.
But it seems like sample weights are the only way.