Multi-Label Classification in PyTorch

Let’s take fine-tuning a ResNet as an example:

import torch
import torch.nn as nn
from torchvision import models


class ResNet50(nn.Module):
    def __init__(self, num_classes):
        super(ResNet50, self).__init__()

        # Load the ResNet-50 architecture with pretrained ImageNet weights
        original_model = models.resnet50(pretrained=True)

        # Everything except the last linear layer
        self.features = nn.Sequential(*list(original_model.children())[:-1])

        # Number of input features of the original last layer
        num_feats = original_model.fc.in_features

        # Plug in our classifier
        self.classifier = nn.Sequential(
            nn.Linear(num_feats, num_classes)
        )

        # Init the weights of the last layer
        for m in self.classifier:
            nn.init.kaiming_normal_(m.weight)

        # Freeze all weights except the last classifier layer
        # for p in self.features.parameters():
        #     p.requires_grad = False

    def forward(self, x):
        f = self.features(x)
        f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y

Is your question about whether to apply the sigmoid here?

    def forward(self, x):
        f = self.features(x)
        f = f.view(f.size(0), -1)
        y = self.classifier(f)
        y = torch.sigmoid(y)  # Is this better?
        return y

Or one level higher, in the loss function? Currently there is no difference.
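
For example, a quick check (a minimal sketch with random toy logits and multi-hot targets, purely for illustration): applying the sigmoid in forward and feeding the result to nn.BCELoss gives the same value as feeding the raw logits to nn.MultiLabelSoftMarginLoss.

import torch
import torch.nn as nn

logits = torch.randn(4, 10)                      # raw scores from the model (no sigmoid)
targets = torch.randint(0, 2, (4, 10)).float()   # multi-hot label vectors

loss_a = nn.BCELoss()(torch.sigmoid(logits), targets)     # sigmoid applied in forward
loss_b = nn.MultiLabelSoftMarginLoss()(logits, targets)   # sigmoid applied inside the loss
print(loss_a.item(), loss_b.item())              # the two values match up to float error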

Ideally, in the future, you should use MultiLabelSoftMarginLoss during training, once it is made numerically stable and faster; see PyTorch issue 1516.

Currently, MultiLabelSoftMarginLoss in PyTorch is implemented naively, as a separate sigmoid pass followed by the cross-entropy; if the two were fused, it would be both faster and more numerically accurate.
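
For reference, here is a minimal training-step sketch with MultiLabelSoftMarginLoss and the ResNet50 wrapper above (the number of labels, the dummy batch, and the optimizer settings are arbitrary choices for illustration); note that the loss expects raw logits and multi-hot float targets:

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical setup: 20 labels and a dummy batch, using the ResNet50 wrapper above
model = ResNet50(num_classes=20)
criterion = nn.MultiLabelSoftMarginLoss()       # applies the sigmoid internally
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

images = torch.randn(8, 3, 224, 224)            # dummy image batch
targets = torch.randint(0, 2, (8, 20)).float()  # multi-hot label vectors

optimizer.zero_grad()
logits = model(images)                          # raw scores: no sigmoid in forward
loss = criterion(logits, targets)
loss.backward()
optimizer.step()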

The proper way is to use the log-sum-exp trick to simplify the Sigmoid Cross-Entropy (SCE) expression from this (obtained by naively substituting the sigmoid into the cross-entropy function):

SCE(x, y') = −1/n ∑_i [ t_i · (x_i − ln(1 + e^(x_i))) + (1 − t_i) · (−ln(1 + e^(x_i))) ]

with t_i (read target_i) being the elements of y'

to this

SCE(x, y') = −1/n ∑_i [ t_i · x_i − max(x_i, 0) − ln(1 + e^(−|x_i|)) ]

This form is more numerically stable and much faster to compute.

A full explanation of each simplification step can be found in my own PyTorch-like framework here
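
As a sketch, here is my own direct transcription of the two formulas above into PyTorch (not the library's internal code); the naive form produces nan for large logits while the log-sum-exp form stays finite:

import torch

def sce_naive(x, t):
    # Direct substitution of the sigmoid into the cross-entropy
    softplus = torch.log(1 + torch.exp(x))   # ln(1 + e^x), overflows when x is large
    return -(t * (x - softplus) + (1 - t) * (-softplus)).mean()

def sce_stable(x, t):
    # Log-sum-exp form: t_i * x_i - max(x_i, 0) - ln(1 + e^(-|x_i|))
    stable_term = torch.clamp(x, min=0) + torch.log1p(torch.exp(-x.abs()))
    return -(t * x - stable_term).mean()

x = torch.tensor([[100.0, -100.0, 0.5]])     # large positive, large negative, moderate logit
t = torch.tensor([[1.0, 0.0, 1.0]])

print(sce_naive(x, t))    # tensor(nan): exp(100.) overflows in float32
print(sce_stable(x, t))   # tensor(0.1580): finite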

Note: ln(1 + x) is also numerically unstable when x is much smaller than 1. In floating point, 1 + x gets rounded to exactly 1, so ln(1) returns 0, even though ln(1 + x) ≈ x for small x; the network would then wrongly stop training because there is no gradient. NumPy has the log1p function to avoid this, and PyTorch provides torch.log1p as well.
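
A quick illustration:

import torch

x = torch.tensor(1e-8)     # x << 1, below float32 machine epsilon relative to 1
print(torch.log(1 + x))    # tensor(0.): 1 + x rounds to exactly 1, so the log is 0
print(torch.log1p(x))      # tensor(1.0000e-08): the small value is preserved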
