How to merge two models trained on different classes?

I have an imbalanced dataset: some classes have 3000 images each, while others have only 40. So my idea is to train a model A on the classes with many images, and a model B on the classes with few images.
Suppose both models perform well on their own data; how can I then merge these two models to predict all images?

You could take the outputs (or the activations from the penultimate layer) of both models and feed them to another classifier.
However, this won’t solve your issue of imbalanced classes.
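A minimal sketch of that first idea, assuming both trained models are standard PyTorch networks whose last child module is the classification layer (the names model_a / model_b, the feature sizes, and the class count are placeholders, not anything from your setup):

import torch
import torch.nn as nn

class MergedClassifier(nn.Module):
    def __init__(self, model_a, model_b, feat_a=512, feat_b=512, num_classes=10):
        super().__init__()
        # Keep everything except each model's final classification layer.
        self.features_a = nn.Sequential(*list(model_a.children())[:-1])
        self.features_b = nn.Sequential(*list(model_b.children())[:-1])
        # Freeze the trained feature extractors; only the new head is trained.
        for p in self.features_a.parameters():
            p.requires_grad = False
        for p in self.features_b.parameters():
            p.requires_grad = False
        self.head = nn.Linear(feat_a + feat_b, num_classes)

    def forward(self, x):
        fa = self.features_a(x).flatten(1)  # penultimate features of model A
        fb = self.features_b(x).flatten(1)  # penultimate features of model B
        return self.head(torch.cat([fa, fb], dim=1))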

If you don’t want to train another classifier on top of your models, you could also use the predictions directly and try to find a threshold for deciding when to use one model over the other.

Hello Marsggbo!

I don’t think that there is a particularly good way to merge
the models as you suggest.

The intuition is as follows:

Let’s assume – as is likely the case – that the early layers
of your trained A and B models are both extracting essentially
the same low-level features that are appropriate for use by
the later layers of the two models for performing the A and
B classifications. Even so, because the two models were trained
independently, the particular way in which those features are
encoded in the early layers will happen to be different, and
the A-model low-level feature encoding won’t happen to match
what the later layers of the B-model are expecting.

I think it makes more sense to use one of the various established
techniques for dealing with unbalanced data.

Your supposition that your B-model performs well on your B data
implies that you have enough B data, even if it’s a lot less
than your A data. If you simply don’t have enough B data, things
get harder.

One approach would be to sample your B data more heavily so that
you train your model with balanced samples.

To be concrete, let’s say that you have 5 A classes, each with
3000 sample images, and 5 B classes, each with only 40 images,
for 10 classes total. If you were to use a batch size of 50
for each training step, you could force your (otherwise randomly
sampled) batches to have exactly 5 images from each class. Of
course, as you train, you’ll start seeing repeats of your B-class
images much sooner than A-class repeats. (But that’s okay.)
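A sketch of that forced-balance idea: every batch of 50 gets exactly
5 images from each of the 10 classes. BalancedBatchSampler is my own
illustration, not a built-in PyTorch class, and `labels` is an assumed
list mapping sample index to class index.

import random
import torch

class BalancedBatchSampler(torch.utils.data.Sampler):
    def __init__(self, labels, num_classes=10, per_class=5, num_batches=100):
        self.by_class = [[i for i, y in enumerate(labels) if y == c]
                         for c in range(num_classes)]
        self.per_class = per_class
        self.num_batches = num_batches

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for idxs in self.by_class:
                # Sample with replacement, so small classes repeat sooner.
                batch.extend(random.choices(idxs, k=self.per_class))
            random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches

# Used via the batch_sampler argument:
# loader = torch.utils.data.DataLoader(train_dataset,
#                                      batch_sampler=BalancedBatchSampler(labels))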

You don’t have to force your batches to have exactly 5 images
from each class. You could sample so that you “probably” have
5 images. You might take a look at
torch.utils.data.sampler.WeightedRandomSampler
for doing this kind of thing.
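A minimal usage sketch, assuming `labels` is a list of integer class
indices for your training samples and `train_dataset` is your dataset:

import torch

# Per-sample weight = inverse frequency of the sample's class, so each
# class is drawn with roughly equal probability.
class_counts = torch.bincount(torch.tensor(labels))
sample_weights = 1.0 / class_counts[torch.tensor(labels)].float()
sampler = torch.utils.data.sampler.WeightedRandomSampler(
    sample_weights, num_samples=len(labels), replacement=True)
loader = torch.utils.data.DataLoader(train_dataset, batch_size=50, sampler=sampler)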

It could also be better to not fully reweight your B data. For
example, you could partially reweight your B data so that a
typical batch has, say, 40 A images and 10 B images. (What
scheme works best depends on the details of your problem.)

Best regards.

K. Frank

Thanks so much for your detailed reply. I’ve tried WeightedSampler, but it doesn’t seem to work. A comparison of the results is as follows:

  • Before using WeightedSampler
    Accuracy on training set: top1/5 = 53% / 82%
    Accuracy on testing set: top1/5 = 48% / 77%

  • After using WeightedSampler
    Accuracy on training set: top1/5 = 67% / 90%
    Accuracy on testing set: top1/5 = 30% / 69%

  • I also tried WeightedLoss; the result is:
    Accuracy on training set: top1/5 = 21% / 53%
    Accuracy on testing set: top1/5 = 22% / 53%

So what else should I do now to further improve the performance of my model? Thanks again!

Hi Marsggbo!

Interesting results …

A quick comment:

With WeightedSampler your training accuracy went up, but
your test accuracy went down. This suggests that you may be
over-fitting. I might guess that you are over-fitting your B-class
images, as you have relatively few of them and are presumably
reusing each one in your training more often.

You don’t say precisely how you are reweighting (neither with
WeightedSampler nor with WeightedLoss). If you are fully
reweighting (so that each batch has, on average, the same
number of A-class and B-class images), you might want to
try only partially reweighting. (If you use WeightedSampler,
but use a flat weight so that you don’t actually reweight, you
should get the same result as not using WeightedSampler
at all.) It’s possible (but no guarantee) that somewhere in
between flat weighting and full reweighting you get improved
results on your test accuracy as well as your training accuracy.

It’s worth a try, I think. (The same comments also apply to only
partially reweighting WeightedLoss.)

If that doesn’t help, it may be a consequence of having relatively
few B-class images. If so, you might consider using a network
with fewer layers / parameters, in the hope that it reduces the tendency
to over-fit. (Or train for less time, or try smaller / bigger batches, etc.)

Best.

K. Frank

Sorry I didn’t explain clearly: I tried WeightedSampler and WeightedLoss separately, not both at once. Here is my code:

  • WeightedSampler
import numpy as np
import torch

# per-sample class labels taken from the ImageFolder dataset
train_labels = np.array(train_dataset.imgs)[:, 1]
count_labels = torch.tensor([(train_labels == i).sum() for i in range(100)])  # 100 classes
weight = 1. / count_labels.float()  # inverse-frequency weight per class
samples_weight = torch.tensor([weight[i] for i in train_labels])  # weight per sample
weight_sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, len(train_labels))
train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size,
            sampler=weight_sampler,
            num_workers=2, pin_memory=pin_memory)
  • WeightedLoss
import numpy as np
import torch

labels = np.array(train_loader.dataset.imgs)[:, 1]
weight = []
for i in range(100):  # 100 classes
    weight.append(1. / (labels == str(i)).sum().item())  # inverse class frequency
weight = torch.Tensor(weight).float()
criterion = torch.nn.CrossEntropyLoss(weight=weight).to(device)

Are there any errors in my code?

Best regards,
marsggbo

Hi Marsggbo!

I didn’t mean to suggest that you were using WeightedSampler
and WeightedLoss at the same time. As I understood it, in the
one calculation you reported you used just WeightedSampler,
while in the next you used just WeightedLoss. That makes
sense.

I don’t see any errors. (But I could be wrong …)

I do see an apparent inconsistency. With WeightedSampler you get
your labels from train_dataset.imgs and then do a numerical test,
train_labels==i. But with WeightedLoss you get your labels from
train_loader.dataset.imgs and then do a string test, labels==str(i).

This isn’t necessarily wrong, but seeing the two different kinds
of tests for essentially the same purpose looks potentially fishy.

On the larger issue of how to do the reweighting, if I understand
your code correctly, it looks like you are doing what I called
“full reweighting.” That is, if class 17 had 100 samples, while
class 22 had only 25 samples, you would weight class 22 four
times as heavily as class 17. (That is, in your code, a weight
of 0.01 for class 17 and a weight of 0.04 for class 22.)

In what I called flat weighting, the weights would be the same.
(For example, a weight of 1 for both classes.)

What I was suggesting before is that you try “partial reweighting.”
That is, you could weight your classes, for example, by the
square-root of the full weight. So, using the numbers from
the above example, class 17 would have a weight of 0.1, and
class 22 would have a weight of 0.2. Instead of weighting
class 22 four times as heavily as class 17, you would weight
it only twice as heavily.

More generally, you could introduce a parameter – call it r – that
runs from 0 to 1. When r = 0.0, you get flat weighting, and when
r = 1.0 you get full reweighting. In the above square-root
example, you would have r = 0.5. That is,

weight = (1. / count_labels.float())**r

My suggestion was that you calculate your training and testing
accuracies for several values of r running from 0 to 1. It is
possible that for some range of r values you get not only
better training accuracy, but better testing accuracy, as well.

The intuition would be that you get some of the benefits of
reweighting, so that you don’t “ignore” classes with only a
few images, but that you don’t reweight so heavily that you
over-fit classes with only a few images.

(You could also try this partially reweighting scheme with
WeightedLoss.)
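A sketch of that sweep, assuming count_labels is the per-class count
tensor from your earlier post and train_labels is a LongTensor of
integer class indices (in your code it is a NumPy array, so it would
need converting first):

import torch

for r in (0.0, 0.25, 0.5, 0.75, 1.0):
    # r = 0.0 -> flat weighting; r = 1.0 -> full reweighting.
    weight = (1. / count_labels.float()) ** r
    # Per-sample weights for WeightedRandomSampler, by indexing the
    # per-class weights with each sample's class index.
    samples_weight = weight[train_labels]
    sampler = torch.utils.data.sampler.WeightedRandomSampler(samples_weight, len(train_labels))
    # Or, for the weighted-loss variant, pass the per-class weights directly:
    criterion = torch.nn.CrossEntropyLoss(weight=weight)
    # ... train with this sampler (or criterion) and record train/test accuracy ...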

Good luck.

K. Frank
