How to use ResNet for multilabel classification

[PyTorch newbie] I have a dataset of images, and each image has 20+ attributes. I need to train a classifier that takes an image as input and returns the predicted attributes as output. As a first step, I would like to fine-tune ResNet, and I need to complete the task using PyTorch. How would you recommend proceeding? Thank you!


If your architecture returns logits (which is often the case), you shouldn't need to change the model architecture beyond making the final layer output one logit per attribute. I would recommend fine-tuning with a multilabel loss function, e.g., MultiLabelSoftMarginLoss (PyTorch 1.11.0 documentation), and iterating on the model and data from there.


Thanks for the hint! I'll check which loss function ResNet is trained with and try changing it as you suggested.
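ResNet itself has no built-in loss; the loss lives in the training loop, and ImageNet-style classification is typically trained with `nn.CrossEntropyLoss`. A minimal sketch of what the swap changes, using random logits (the batch size and 20 attributes are placeholders): cross-entropy expects one integer class index per image, while `nn.MultiLabelSoftMarginLoss` expects a float 0/1 vector per image. The sketch also checks that, with the default `'mean'` reduction, it gives the same value as `nn.BCEWithLogitsLoss`, which is often used for the same purpose.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_attributes = 4, 20  # placeholder sizes
logits = torch.randn(batch_size, num_attributes)  # raw model outputs, no sigmoid applied

# Single-label setup: exactly one class index per image (dtype long, shape (4,)).
class_idx = torch.randint(0, num_attributes, (batch_size,))
single_label_loss = nn.CrossEntropyLoss()(logits, class_idx)

# Multilabel setup: an independent 0/1 flag per attribute, stored as floats (shape (4, 20)).
multi_hot = torch.randint(0, 2, (batch_size, num_attributes)).float()
multilabel_loss = nn.MultiLabelSoftMarginLoss()(logits, multi_hot)

# Every sample has the same number of attributes, so averaging per sample and then over
# the batch equals averaging over all elements, which is what BCEWithLogitsLoss does.
bce_loss = nn.BCEWithLogitsLoss()(logits, multi_hot)
assert torch.allclose(multilabel_loss, bce_loss)

print(single_label_loss.item(), multilabel_loss.item())
```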