Hi, I am doing semantic segmentation with a large class imbalance (5 classes), so I am passing a weight array into my loss, as in loss_fn = torch.nn.CrossEntropyLoss(weight=loss_weights)

Since it is very hard to choose these weights by hand, I train for 200 epochs and then make loss_weights trainable parameters. But when I do this I get the error:
RuntimeError: the derivative for 'weight' is not implemented

How can I get around this? Any other suggestions for dealing with this kind of class imbalance are also welcome.
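
On the error itself, one possible workaround is to skip the weight argument and apply the class weights by hand to the unreduced loss, so that autograd can differentiate with respect to them. The sketch below is only that: the function name weighted_ce_trainable and the tensor shapes are made up for illustration, and it assumes logits of shape (N, C, H, W) with integer targets of shape (N, H, W). It reproduces the weighted-mean reduction that CrossEntropyLoss(weight=...) uses.

```python
import torch
import torch.nn.functional as F


def weighted_ce_trainable(logits, target, class_weights):
    """Cross entropy with per-class weights applied manually (hypothetical helper).

    logits:        (N, C, H, W) raw scores
    target:        (N, H, W) integer class indices
    class_weights: (C,) tensor, may have requires_grad=True
    """
    # Per-pixel loss without any reduction or built-in weighting.
    per_pixel = F.cross_entropy(logits, target, reduction="none")  # (N, H, W)
    # Look up the weight of each pixel's target class; indexing is differentiable
    # with respect to class_weights.
    w = class_weights[target]  # (N, H, W)
    # Weighted mean, matching CrossEntropyLoss(weight=...) with reduction="mean".
    return (per_pixel * w).sum() / w.sum()


# Toy data standing in for a real batch (assumed shapes, 5 classes).
torch.manual_seed(0)
logits = torch.randn(2, 5, 8, 8, requires_grad=True)
target = torch.randint(0, 5, (2, 8, 8))
loss_weights = torch.nn.Parameter(torch.ones(5))

loss = weighted_ce_trainable(logits, target, loss_weights)
loss.backward()  # loss_weights.grad is now populated
```

One caveat: if the weights are simply minimized together with the network, they will tend to drift toward whichever classes already have low loss, so some constraint or regularizer on them is probably needed.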

For semantic segmentation I wouldn’t use cross entropy. See Sudre et al. (see also Crum et al.), which proposes a generalized Dice coefficient weighted for class imbalance. It’s a good start. There are TF implementations of the generalized Dice coefficient that you can easily port to PyTorch, here.
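
For reference, a minimal PyTorch sketch of a generalized Dice loss in the spirit of Sudre et al. might look like the following. It is not a port of any particular TF implementation. The assumptions: logits are (N, C, H, W), targets are (N, H, W) integer labels, each class is weighted by the inverse of its squared reference volume, sums run over the whole batch, and classes absent from the batch get zero weight (a simplification chosen here to avoid division by zero).

```python
import torch
import torch.nn.functional as F


def generalized_dice_loss(logits, target, eps=1e-6):
    """Generalized Dice loss sketch (hypothetical helper, not a library function)."""
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                        # (N, C, H, W)
    onehot = F.one_hot(target, num_classes)                 # (N, H, W, C)
    onehot = onehot.permute(0, 3, 1, 2).to(probs.dtype)     # (N, C, H, W)

    dims = (0, 2, 3)                      # sum over batch and spatial dims
    ref_volume = onehot.sum(dims)         # pixels per class, shape (C,)

    # Inverse squared volume; classes with no pixels in the batch get weight 0.
    weights = torch.where(
        ref_volume > 0,
        1.0 / ref_volume.clamp(min=1.0) ** 2,
        torch.zeros_like(ref_volume),
    )

    intersect = (probs * onehot).sum(dims)                  # (C,)
    denom = (probs + onehot).sum(dims)                      # (C,)
    dice = 2.0 * (weights * intersect).sum() / ((weights * denom).sum() + eps)
    return 1.0 - dice


# Toy usage with assumed shapes (5 classes).
torch.manual_seed(0)
logits = torch.randn(2, 5, 8, 8, requires_grad=True)
target = torch.randint(0, 5, (2, 8, 8))
loss = generalized_dice_loss(logits, target)
loss.backward()
```

Because rare classes receive larger weights, a handful of mislabeled pixels in a tiny class can move this loss noticeably, which is something to watch for with noisy labels.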

I think it depends a lot on the relative weighting of the two losses: they clearly take different ranges of values, so one may dominate the other.
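
To make the scale issue concrete, here is a rough sketch that assumes the two losses in question are cross entropy and a Dice-style loss. Cross entropy is unbounded above, while a Dice loss stays in [0, 1], so an explicit mixing factor helps. The names soft_dice_loss, combined_loss and alpha are hypothetical, and the plain soft Dice below is just a stand-in for whichever Dice variant is actually used.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits, target, eps=1e-6):
    """Plain (unweighted) soft Dice loss, averaged over classes; bounded in [0, 1]."""
    probs = F.softmax(logits, dim=1)                                   # (N, C, H, W)
    onehot = F.one_hot(target, logits.shape[1])                        # (N, H, W, C)
    onehot = onehot.permute(0, 3, 1, 2).to(probs.dtype)                # (N, C, H, W)
    dims = (0, 2, 3)
    intersect = (probs * onehot).sum(dims)
    denom = (probs + onehot).sum(dims)
    return 1.0 - ((2.0 * intersect + eps) / (denom + eps)).mean()


def combined_loss(logits, target, alpha=0.5):
    """Convex combination of the two losses; alpha is a hyperparameter to tune.

    Also returns the detached components so their magnitudes can be logged.
    """
    ce = F.cross_entropy(logits, target)
    dice = soft_dice_loss(logits, target)
    total = alpha * ce + (1.0 - alpha) * dice
    return total, ce.detach(), dice.detach()


# Toy usage with assumed shapes (5 classes).
torch.manual_seed(0)
logits = torch.randn(2, 5, 8, 8, requires_grad=True)
target = torch.randint(0, 5, (2, 8, 8))
total, ce_value, dice_value = combined_loss(logits, target, alpha=0.5)
total.backward()
```

Logging ce_value and dice_value separately during training is a cheap way to see whether one term is swamping the other before settling on alpha.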