Question about fine-tuning a fcn_resnet101 model with 2 classes

I want to fine-tune the pre-trained FCN-ResNet segmentation model on my own dataset, which only contains two classes.

I am following this tutorial, but I want to understand how the code knows which 2 classes I am referring to. Let's say my custom dataset contains images with only two classes, background and car for example. When I set num_classes = 2, how does the code know that I'm specifically referring to background and car, instead of any of the other 20 classes the model was originally trained on, like cat, dog, horse, bicycle, etc.?

Can it just tell by the specific colors in the masked images? Or does it have to be specified somewhere in the code when fine-tuning?

I am new to PyTorch and machine learning, so I would like to gain a better understanding of this so I can fine-tune the FCN-ResNet model correctly.


Basically, ResNet is trained on the ImageNet dataset. When doing transfer learning, you freeze the earlier pre-trained layers and train the last couple of layers on your own dataset.
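In code, that freezing step might look like the following minimal sketch, assuming torchvision's fcn_resnet101 (newer torchvision versions use a `weights=` argument instead of `pretrained=True`):

```python
import torchvision

# Load the pre-trained segmentation model.
model = torchvision.models.segmentation.fcn_resnet101(pretrained=True)

# Freeze the pre-trained backbone so only the (new) head gets trained.
for param in model.backbone.parameters():
    param.requires_grad = False
```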

Thanks so much for your response. So basically you are saying that there is no need to specify which two classes you are referring to, right?

Yeah. If you have any doubts in the future, message me.

As can be seen in the tutorial, num_classes is used to initialize a new classification layer.
The output activations will provide num_classes outputs, which represent the logits for each class, i.e. output[0] would be the logits for class 0.
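For fcn_resnet101 that classification layer is the segmentation head, which outputs one channel per class. A minimal sketch of replacing it, assuming torchvision's FCNHead (an assumption about the exact head class the tutorial uses):

```python
import torch
import torchvision
from torchvision.models.segmentation.fcn import FCNHead

num_classes = 2  # e.g. background and car

model = torchvision.models.segmentation.fcn_resnet101(pretrained=True)

# Replace the head so it outputs num_classes channels, one logit map per class.
model.classifier = FCNHead(2048, num_classes)

model.eval()
x = torch.randn(1, 3, 224, 224)   # dummy input batch
out = model(x)['out']             # the model returns a dict of outputs
print(out.shape)                  # torch.Size([1, 2, 224, 224])
```

Here out[:, 0] is the logit map for class 0 (background) and out[:, 1] the one for class 1 (car).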

If you define your Dataset, you would have to make sure that each sample is mapped to a valid class index.

E.g. in your use case with num_classes = 2, your target should contain the values 0 and 1.
The mapping between these class indices and their meaning is defined by you.
I.e. if you only care about the background and car classes, you would have to make sure that you are only using samples containing these two classes.
The mapping is arbitrary. So you could define background = 0 and car = 1 or vice versa.
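As a minimal sketch of how such a target fits together with the model output during training (the 4x4 mask here is hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical 4x4 target mask: 0 = background, 1 = car.
target = torch.tensor([[0, 0, 1, 1],
                       [0, 1, 1, 1],
                       [0, 0, 1, 0],
                       [0, 0, 0, 0]]).unsqueeze(0)  # [batch_size=1, H, W], dtype long

# Matching (random) model output: one logit map per class.
logits = torch.randn(1, 2, 4, 4)                    # [batch_size, num_classes, H, W]

# nn.CrossEntropyLoss takes raw logits and class-index targets.
loss = nn.CrossEntropyLoss()(logits, target)
print(loss)
```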

Thanks so much for your detailed response!

I have a few questions to gain more clarity and understanding from your response:

  1. " If you define your Dataset , you would have to make sure that each sample is mapped to a valid class index."

When you are talking about mapping, is this done simply by using the masked image? Like in my masked images, all the background pixels are black, and the ones that belong to the car class are blue. Is using that enough to map the classes correctly, or do you need to do something else besides that to map it? You mentioned output[0], but I didn't see anything in the tutorial where you index an output and map it to a class.

  2. "if you only care about the background and car classes, you would have to make sure that you are only using samples containing these two classes."

In my image dataset, I only have two labels/classes, 'background' and 'car'. Basically, everything that is not a car is marked as background. In some of those unlabeled images, I have objects in the background like a 'chair' or 'bird' that are classes in the original pre-trained model, but in my masked images I mark them as background because I don't care about those classes. Would doing this (masking objects that belonged to a different class in the original pre-trained model as background) confuse the model and cause a loss in accuracy?

Basically, I want to clarify whether "only using samples containing these two classes" means that the original unlabeled images should not contain objects belonging to classes other than 'background' and 'car', or whether it's okay to have images like these, as long as the masked/labeled images only contain the 'background' and 'car' classes?

Most likely not. Based on your description, it seems you are dealing with color images as masks.
For a usual segmentation use case, the mask should be a LongTensor with the shape [batch_size, height, width] containing class indices.

Have a look at this post to see how a mapping between the color values and your class indices can be created.
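For the black background / blue car masks described above, such a color-to-index mapping could look like this minimal sketch (the exact RGB values are assumptions based on the description; this is not the code from the linked post):

```python
import torch

# Hypothetical color mask: an [H, W, 3] uint8 RGB image,
# e.g. loaded from a PNG mask file.
color_mask = torch.zeros(4, 4, 3, dtype=torch.uint8)
color_mask[1:3, 1:3] = torch.tensor([0, 0, 255], dtype=torch.uint8)  # blue "car" patch

# Color-to-class mapping, as described in the question:
# black -> background (0), blue -> car (1).
mapping = {
    (0, 0, 0): 0,      # background
    (0, 0, 255): 1,    # car
}

# Build the [H, W] LongTensor of class indices.
target = torch.zeros(color_mask.shape[:2], dtype=torch.long)
for color, class_idx in mapping.items():
    matches = (color_mask == torch.tensor(color, dtype=torch.uint8)).all(dim=-1)
    target[matches] = class_idx

print(target)  # zeros with a 2x2 block of ones where the car is
```

Stacking these [height, width] targets along a new first dimension then yields the [batch_size, height, width] LongTensor mentioned above.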