PyTorch multi-class classification where an image belongs to multiple classes

Hey, I would love to know how to implement a model which takes in a 3,224,224 chest X-ray image that can belong to multiple classes and classify the image correctly. Any help would be greatly appreciated.

Hi Veer!

You are asking about what is called a multi-label, multi-class
classification problem.

Let me illustrate the typical use case with an example:

Let’s say that you have three classes: “pneumonia,” “artery blockage,”
and “tumor,” and that a single image can display all, some, or none
of these features – that is, it can belong to multiple classes at the
same time.

The input to your model would typically be a batch of images, that is,
a tensor of shape
[nBatch, nChannel = 3, height = 224, width = 224]. Your
target – your ground-truth labels – would be a batch of multi-label
labels of shape [nBatch, nClass = 3], and each value in the target
tensor would (in the case of “hard” labels) be 0.0 if the image were not
in the corresponding class and 1.0 if it were in the corresponding class.
(In the “soft”-label case, these values would be numbers ranging from
0.0 to 1.0 that would represent the probability that the image is in
the corresponding class.) For a given image, all, some, or none of the
values along the nClass dimension could be 1.0. (In the soft-label
case there would be no requirement that, for a given image, the values
along the nClass dimension sum to 1.0.)
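
To make the shapes concrete, here is a minimal sketch of the input and
target tensors described above. The batch size of 4, the random pixel
values, and the particular label rows are assumptions made up for
illustration; real code would load actual images and annotations.

```python
import torch

nBatch, nClass = 4, 3  # illustrative sizes, not from a real dataset

# A batch of 3-channel images; random values stand in for real X-rays.
images = torch.randn(nBatch, 3, 224, 224)

# Hard multi-label targets: one row per image, one column per class
# (pneumonia, artery blockage, tumor). The dtype must be floating point.
target = torch.tensor([
    [1.0, 0.0, 0.0],  # pneumonia only
    [1.0, 1.0, 1.0],  # all three
    [0.0, 0.0, 0.0],  # none
    [0.0, 1.0, 1.0],  # artery blockage and tumor
])

# Soft-label variant: per-class probabilities; rows need not sum to 1.0.
soft_target = torch.tensor([
    [0.9, 0.1, 0.0],
    [0.8, 0.7, 0.6],
    [0.0, 0.2, 0.1],
    [0.3, 0.9, 0.5],
])

print(images.shape, target.shape, soft_target.shape)
```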

The last layer of your model would typically be a Linear with
out_features = nClass. You would interpret the output values as
raw-score logits for the image being in each of the classes, and you
would use BCEWithLogitsLoss as your loss criterion.
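
As a rough sketch of that last layer and loss: the feature size of 16
and the random features and targets below are placeholders I made up,
standing in for whatever the earlier layers of your model produce. The
0.5 threshold at the end is just one common choice for turning
probabilities into yes/no predictions.

```python
import torch
import torch.nn as nn

nBatch, nFeatures, nClass = 4, 16, 3  # nFeatures is a hypothetical size

last_layer = nn.Linear(in_features=nFeatures, out_features=nClass)
criterion = nn.BCEWithLogitsLoss()

# Placeholder inputs: features from earlier layers and random hard labels.
features = torch.randn(nBatch, nFeatures)
target = torch.randint(0, 2, (nBatch, nClass)).float()

logits = last_layer(features)     # shape [nBatch, nClass], raw scores
loss = criterion(logits, target)  # scalar; the sigmoid is applied inside the loss
loss.backward()

# At inference time, convert logits to per-class probabilities and threshold.
probs = torch.sigmoid(logits)
predicted = probs > 0.5           # boolean tensor of shape [nBatch, nClass]
```

Note that no sigmoid is applied to the model output before the loss;
BCEWithLogitsLoss expects the raw logits.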

(For this kind of image classification problem, you would most likely
use a convolutional-neural-network (CNN) architecture where the
first several layers would be convolutions, only switching over to
fully-connected Linear layers at the end.)
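
A toy version of such an architecture might look like the following.
The class name SmallMultiLabelCNN, the layer sizes, and the use of
global average pooling are my own illustrative choices, not a
recommended design; in practice you would more likely start from an
established (possibly pretrained) network and replace its final layer.

```python
import torch
import torch.nn as nn

class SmallMultiLabelCNN(nn.Module):
    """Toy CNN: convolutions first, a single Linear layer at the end."""

    def __init__(self, nClass=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),          # 224 -> 112
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),          # 112 -> 56
            nn.AdaptiveAvgPool2d(1),  # -> [nBatch, 32, 1, 1]
        )
        self.classifier = nn.Linear(32, nClass)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)       # -> [nBatch, 32]
        return self.classifier(x)     # raw logits, no sigmoid here

model = SmallMultiLabelCNN(nClass=3)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)
```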


K. Frank

Thank you so much. Very much appreciated.