How to merge two models trained on different classes?

Hi Marsggbo!

I didn’t mean to suggest that you were using WeightedSampler
and WeightedLoss at the same time. As I understood it, in the
one calculation you reported you used just WeightedSampler,
while in the next you used just WeightedLoss. That makes
sense.

I don’t see any errors. (But I could be wrong …)

I do see an apparent inconsistency. With WeightedSampler
you get your labels from train_dataset.imgs, and then do
a numerical test, train_labels==i.

But with WeightedLoss
you get your labels from train_loader.dataset.imgs, and then
do a string test, labels==str(i).

This isn’t necessarily wrong, but seeing the two different kinds
of tests for essentially the same purpose looks potentially fishy.
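To show why the mixed tests deserve a second look, here is a minimal pure-Python sketch. It assumes the labels are integer class indices, as they are in torchvision's `ImageFolder.imgs` (a list of `(path, class_index)` tuples); the toy `imgs` list below is hypothetical. Whether your labels end up as ints or strings depends on how you extract them, so the point is simply to extract them once and use one kind of test everywhere.

```python
# Hypothetical stand-in for ImageFolder.imgs: (path, class_index) tuples
# with integer class indices.
imgs = [("cat/1.jpg", 0), ("dog/1.jpg", 1), ("dog/2.jpg", 1), ("fox/1.jpg", 2)]
num_classes = 3

# Extract the integer labels once and reuse them for sampler and loss alike.
labels = [class_index for _, class_index in imgs]

# Numeric test: counts each class correctly.
counts_numeric = [sum(1 for y in labels if y == i) for i in range(num_classes)]

# String test against integer labels: never matches, so every count is 0.
counts_string = [sum(1 for y in labels if y == str(i)) for i in range(num_classes)]

print(counts_numeric)  # [1, 2, 1]
print(counts_string)   # [0, 0, 0]
```

If one of the two code paths had silently produced zero counts like this, the resulting weights would be wrong without any error being raised.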

On the larger issue of how to do the reweighting, if I understand
your code correctly, it looks like you are doing what I called
“full reweighting.” That is, if class 17 had 100 samples, while
class 22 had only 25 samples, you would weight class 22 four
times as heavily as class 17. (That is, in your code, a weight
of 0.01 for class 17 and a weight of 0.04 for class 22.)

In what I called flat weighting, the weights would be the same.
(For example, a weight of 1 for both classes.)
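Here is a minimal pure-Python sketch of the two schemes, using the counts from the example above (class 17 with 100 samples, class 22 with 25). The dictionary layout is just for illustration; in your code the counts would come from count_labels.

```python
# Class counts from the example: class 17 has 100 samples, class 22 has 25.
counts = {17: 100, 22: 25}

# Full reweighting: weight each class by 1 / count.
full = {c: 1.0 / n for c, n in counts.items()}

# Flat weighting: every class gets the same weight.
flat = {c: 1.0 for c in counts}

print(full)                 # {17: 0.01, 22: 0.04}
print(full[22] / full[17])  # 4.0
print(flat[22] / flat[17])  # 1.0
```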

What I was suggesting before is that you try “partial reweighting.”
That is, you could weight your classes, for example, by the
square-root of the full weight. So, using the numbers from
the above example, class 17 would have a weight of 0.1, and
class 22 would have a weight of 0.2. Instead of weighting
class 22 four times as heavily as class 17, you would weight
it only twice as heavily.

More generally, you could introduce a parameter – call it r – that
runs from 0 to 1. When r = 0.0, you get flat weighting, and when
r = 1.0 you get full reweighting. In the above square-root
example, you would have r = 0.5. That is,

weight = (1. / count_labels.float())**r
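A minimal sketch of that formula in plain Python (rather than on a torch tensor, so it runs without PyTorch). The helper names class_weights and per_sample_weights are hypothetical. The second helper reflects an assumption about how the weights get used: a weighted sampler such as torch.utils.data.WeightedRandomSampler takes one weight per sample, while a weighted loss such as torch.nn.CrossEntropyLoss takes one weight per class.

```python
def class_weights(counts, r):
    """Per-class weights (1 / count) ** r.

    r = 0.0 gives flat weighting, r = 1.0 gives full reweighting,
    and values in between give partial reweighting.
    """
    if not 0.0 <= r <= 1.0:
        raise ValueError("r should lie in [0, 1]")
    return [(1.0 / n) ** r for n in counts]


def per_sample_weights(labels, weights):
    """Expand per-class weights to one weight per sample (the form a
    weighted sampler expects); a weighted loss uses the per-class list."""
    return [weights[y] for y in labels]


# Two classes with 100 and 25 samples (classes 17 and 22 above,
# re-indexed as 0 and 1 for this sketch).
counts = [100, 25]
for r in (0.0, 0.5, 1.0):
    w = class_weights(counts, r)
    print(r, w, w[1] / w[0])
```

With these counts, r = 0.5 reproduces the square-root example (weights of about 0.1 and 0.2, a ratio of 2), while r = 0.0 and r = 1.0 reproduce flat and full weighting.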

My suggestion was that you calculate your training and testing
accuracies for several values of r running from 0 to 1. It is
possible that for some range of r values you get not only
better training accuracy, but better testing accuracy, as well.
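One way to organize that sweep, as a sketch: evaluate is a hypothetical callable that, given r, trains with class weights (1 / count) ** r and returns (train_acc, test_acc). The toy fake_evaluate below only exists so the sketch runs; its numbers are made up and say nothing about real results.

```python
def sweep_r(evaluate, num_points=11):
    """Try r = 0.0, 0.1, ..., 1.0 and record (train_acc, test_acc) for each.

    `evaluate` is a hypothetical callable: given r, it trains a model with
    class weights (1 / count) ** r and returns (train_acc, test_acc).
    """
    results = {}
    for k in range(num_points):
        r = k / (num_points - 1)
        results[r] = evaluate(r)
    # Report the r with the best held-out (second) accuracy.
    best_r = max(results, key=lambda r: results[r][1])
    return best_r, results


def fake_evaluate(r):
    # Toy stand-in so the sketch runs; replace with real training.
    return 0.9 + 0.05 * r, 0.8 + 0.1 * r * (1 - r)


best_r, results = sweep_r(fake_evaluate)
for r, (train_acc, test_acc) in results.items():
    print(f"r={r:.1f}  train={train_acc:.3f}  test={test_acc:.3f}")
print("best r:", best_r)
```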

The intuition would be that you get some of the benefits of
reweighting, so that you don’t “ignore” classes with only a
few images, but that you don’t reweight so heavily that you
over-fit classes with only a few images.

(You could also try this partial reweighting scheme with
WeightedLoss.)
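With a weighted loss, the partial weights go in per class. In PyTorch that would typically mean passing them as the weight tensor of torch.nn.CrossEntropyLoss. To keep the sketch dependency-free, here is a pure-Python weighted cross-entropy that follows the normalization that class documents for reduction='mean' (the weighted sum of per-sample losses divided by the sum of the target weights). The function name and the toy logits are hypothetical.

```python
import math


def weighted_cross_entropy(logits, targets, class_weights):
    """Mean weighted cross-entropy, normalized by the sum of the weights
    of the targets (same convention as CrossEntropyLoss, reduction='mean')."""
    total, norm = 0.0, 0.0
    for row, y in zip(logits, targets):
        # Log-sum-exp with the max subtracted, for numerical stability.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_z - row[y]  # negative log-probability of the true class
        w = class_weights[y]
        total += w * nll
        norm += w
    return total / norm


# Partial reweighting (r = 0.5) for classes with 100 and 25 samples.
r = 0.5
weights = [(1.0 / n) ** r for n in (100, 25)]

logits = [[2.0, 0.0], [2.0, 0.0]]  # the class-1 sample is misclassified
targets = [0, 1]
print(weighted_cross_entropy(logits, targets, weights))
print(weighted_cross_entropy(logits, targets, [1.0, 1.0]))  # flat weighting
```

Because the rarer class gets the larger weight, its poorly classified sample pulls the weighted loss above the flat-weighted one.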

Good luck.

K. Frank
