Hi Federico!
Plan A:
Let’s say that you have a ColorClassifier that does a (very) good job of
classifying color, and a ShapeClassifier that does a (very) good job of
classifying shape.
Just put your input to be classified into both separately. If your
ColorClassifier predicts “red” and your ShapeClassifier predicts
“rectangle,” your “combined” classifier has predicted “red rectangle.”
If this works, just keep it simple and go with it.
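As a rough sketch of Plan A (the label lists, the input size of 16, and the stand-in Linear models are placeholders I made up for illustration, not your actual classifiers):

```python
import torch

# Hypothetical label orders; use whatever order your classifiers were trained with.
COLORS = ["red", "green", "blue"]
SHAPES = ["rectangle", "triangle", "circle"]

# Stand-ins for the two trained classifiers (assumption: each maps a batch
# of flattened inputs of size 16 to three class logits).
color_classifier = torch.nn.Linear(16, 3)
shape_classifier = torch.nn.Linear(16, 3)

def predict_color_shape(x):
    """Run both classifiers on the same batch and pair up their predictions."""
    with torch.no_grad():
        color_idx = color_classifier(x).argmax(dim=-1)
        shape_idx = shape_classifier(x).argmax(dim=-1)
    return [
        f"{COLORS[c]} {SHAPES[s]}"
        for c, s in zip(color_idx.tolist(), shape_idx.tolist())
    ]

labels = predict_color_shape(torch.randn(4, 16))
```

Nothing is trained or shared here; the two models never see each other, which is the whole point of Plan A.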
Plan B:
Let’s say that the dataset you want to run your classifier on has some
structure that links shape and color together somehow. For example,
perhaps rectangles are more likely to be red and triangles are more likely
to be green, or perhaps you have images of red rectangles that also have
green and blue blotches in the background. You might then actually want
to combine your classifiers into one.
Here’s the general approach I would take (making some plausible
assumptions about the architectures of your classifiers):
Let’s assume that the final layer of ColorClassifier is a Linear with
out_features = 3, and, for the sake of argument, with in_features = 8,
that is, Linear (8, 3), and, similarly, let’s assume that the final layer
of ShapeClassifier is Linear (12, 3).
So maybe the final section of ColorClassifier is:

    torch.nn.Sequential (
        # various layers
        torch.nn.Linear (in_features = 128, out_features = 8),
        torch.nn.ReLU(),
        torch.nn.Linear (in_features = 8, out_features = 3)
    )

and, similarly, ShapeClassifier looks like:

    torch.nn.Sequential (
        # various layers
        torch.nn.Linear (in_features = 256, out_features = 12),
        torch.nn.ReLU(),
        torch.nn.Linear (in_features = 12, out_features = 3)
    )

That is, the second-to-last layer of ColorClassifier outputs a “feature
vector” of length 8 (more precisely, a batch of feature vectors of shape
[nBatch, 8]) and the second-to-last layer of ShapeClassifier outputs a
feature vector of length 12.
The basic idea is that you would concatenate these two feature vectors
together into a single feature vector of length 20, and then pass it through
a final classification layer, colorShapeLayer.
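To make the shapes concrete, here is a minimal check of the concatenation step, using random tensors in place of real feature vectors (the batch size of 5 is arbitrary):

```python
import torch

n_batch = 5
color_features = torch.randn(n_batch, 8)    # stand-in for the color feature vectors
shape_features = torch.randn(n_batch, 12)   # stand-in for the shape feature vectors

# Concatenate along the last (feature) dimension, leaving the batch dimension alone.
both = torch.cat((color_features, shape_features), dim=-1)
print(both.shape)
```

The first 8 columns of the result are the color features and the last 12 are the shape features, so the batch dimension has to match between the two inputs.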
The output of colorShapeLayer could be six values, the first three for
the color classification, and the remaining three for the shape
classification. But my intuition is that you would be better off having
colorShapeLayer output nine values, one each for each of the nine
combinations of three colors and three shapes.
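If you go with nine outputs, you need some fixed mapping between a (color, shape) pair and a single class index. One possible convention, purely my own choice for illustration (the helper names are hypothetical), is row-major ordering:

```python
N_COLORS, N_SHAPES = 3, 3

def to_joint(color_idx, shape_idx):
    """Map a (color, shape) index pair to a single class index in 0..8."""
    return color_idx * N_SHAPES + shape_idx

def from_joint(joint_idx):
    """Inverse of to_joint: recover the (color, shape) index pair."""
    return divmod(joint_idx, N_SHAPES)

# Every pair gets its own index, and the mapping round-trips.
all_joint = [to_joint(c, s) for c in range(N_COLORS) for s in range(N_SHAPES)]
```

Any one-to-one mapping would do; what matters is that the training labels and the decoding of predictions use the same one.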
So your combined ColorShapeClassifier might look something like this:

    import torch

    class ColorShapeClassifier (torch.nn.Module):
        def __init__ (self):
            super().__init__()
            self.colorFront = ColorFrontEnd()
            self.shapeFront = ShapeFrontEnd()
            self.colorShapeLayer = torch.nn.Linear (in_features = 20, out_features = 9)

        def forward (self, x):
            yColor = self.colorFront (x)                     # [nBatch, 8]
            yShape = self.shapeFront (x)                     # [nBatch, 12]
            yBoth = torch.cat ((yColor, yShape), dim = -1)   # [nBatch, 20]
            z = torch.nn.functional.relu (yBoth)
            output = self.colorShapeLayer (z)                # [nBatch, 9]
            return output

where all but the last layer of ColorClassifier has been packaged as the
Module ColorFrontEnd, and similarly for ShapeFrontEnd.
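If your classifiers happen to be plain Sequentials like the ones I assumed above, one way to get a front end is to slice off the last layer. This is only a sketch under that assumption; the three-layer model below is a stand-in for your trained ColorClassifier:

```python
import torch

# Stand-in for a trained ColorClassifier with the tail assumed above.
color_classifier = torch.nn.Sequential(
    torch.nn.Linear(in_features=128, out_features=8),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=8, out_features=3),
)

# Slicing a Sequential returns a new Sequential that holds the same layer
# objects, so the front end shares (already trained) weights with the original.
color_front_end = color_classifier[:-1]

features = color_front_end(torch.randn(4, 128))   # shape [4, 8]
```

Because the layers are shared rather than copied, training the front end later also changes the original classifier; use copy.deepcopy() on it first if you want to keep the original intact. If your classifier is a custom Module rather than a Sequential, you would instead need to write a forward() that stops before the last layer.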
ColorShapeClassifier can now be trained to make predictions for the
nine classes that correspond to the nine combinations of the three colors
and three shapes.
Initialize ColorFrontEnd and ShapeFrontEnd with pre-trained weights
from ColorClassifier and ShapeClassifier, respectively (assumed to work
well). The idea is that the pre-trained ColorFrontEnd will output a
feature vector that encodes useful information about color, and similarly
for ShapeFrontEnd. colorShapeLayer starts out initialized randomly.
You would then freeze the pre-trained weights of ColorFrontEnd and
ShapeFrontEnd and just train the weights of colorShapeLayer using
your color-shape, nine-class dataset.
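Here is a minimal sketch of that head-only training phase. The small front ends, the input size of 32, the random data, and the SGD settings are all made up so that the snippet runs on its own; substitute your real front ends and dataset:

```python
import torch

# Stand-ins for the pre-trained front ends (assumption: flattened inputs of size 32).
color_front = torch.nn.Sequential(torch.nn.Linear(32, 8), torch.nn.ReLU())
shape_front = torch.nn.Sequential(torch.nn.Linear(32, 12), torch.nn.ReLU())
color_shape_layer = torch.nn.Linear(in_features=20, out_features=9)

# Freeze the front ends so that backward() computes no gradients for them.
for p in list(color_front.parameters()) + list(shape_front.parameters()):
    p.requires_grad_(False)

# Only the new layer's parameters are handed to the optimizer.
optimizer = torch.optim.SGD(color_shape_layer.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

x = torch.randn(16, 32)                  # fake batch of inputs
target = torch.randint(0, 9, (16,))      # fake nine-class labels

frozen_before = color_front[0].weight.detach().clone()
head_before = color_shape_layer.weight.detach().clone()

for _ in range(3):
    features = torch.cat((color_front(x), shape_front(x)), dim=-1)
    loss = loss_fn(color_shape_layer(features), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

CrossEntropyLoss expects the raw nine-value output (logits) and integer class labels in 0..8, so there is no softmax at the end of the model.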
(After training colorShapeLayer for a while, you would probably want
to fine-tune the entire ColorShapeClassifier model by unfreezing the
“front-end” weights and training all of the weights jointly for a while.)
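A sketch of the fine-tuning phase might look like the following. Giving the front ends a smaller learning rate than the new layer is my own assumption (a common precaution so the pre-trained features are not disturbed too quickly), and the models, data, and learning rates are again placeholders:

```python
import torch

# Stand-ins for the front ends and head as they are after head-only training.
color_front = torch.nn.Sequential(torch.nn.Linear(32, 8), torch.nn.ReLU())
shape_front = torch.nn.Sequential(torch.nn.Linear(32, 12), torch.nn.ReLU())
color_shape_layer = torch.nn.Linear(in_features=20, out_features=9)
front_params = list(color_front.parameters()) + list(shape_front.parameters())
for p in front_params:
    p.requires_grad_(False)   # frozen, as in the head-only phase

# Unfreeze the front ends so that all weights train jointly.
for p in front_params:
    p.requires_grad_(True)

# Per-group learning rates: the first group overrides the default lr.
optimizer = torch.optim.SGD(
    [
        {"params": front_params, "lr": 1e-3},
        {"params": color_shape_layer.parameters()},
    ],
    lr=1e-2,
)
loss_fn = torch.nn.CrossEntropyLoss()

x = torch.randn(16, 32)                  # fake batch of inputs
target = torch.randint(0, 9, (16,))      # fake nine-class labels

features = torch.cat((color_front(x), shape_front(x)), dim=-1)
loss = loss_fn(color_shape_layer(features), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that a new optimizer is built after unfreezing; an optimizer created with only the head's parameters would keep ignoring the front ends.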
(It might also make sense to have more than one Linear layer after you
cat() together the color and shape feature vectors, or to use the outputs
of the third-to-last layers of ColorClassifier and ShapeClassifier as the
color and shape feature vectors that you cat() together. I don’t know of
any way to decide among such variations other than experimenting with
them.)
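For instance, the first variation could replace the single colorShapeLayer with a small two-layer head. The hidden width of 16 is an arbitrary number I picked, not a recommendation:

```python
import torch

# A deeper head in place of the single Linear (20, 9) layer.
color_shape_head = torch.nn.Sequential(
    torch.nn.Linear(in_features=20, out_features=16),   # hidden width is a guess
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=16, out_features=9),
)

# A random batch standing in for the concatenated color and shape features.
logits = color_shape_head(torch.randn(4, 20))
```

It drops into the same place in forward() and still produces nine values per sample, so the rest of the training setup is unchanged.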
Good luck.
K. Frank