Let’s take ResNet finetuning as an example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class ResNet50(nn.Module):
    def __init__(self, num_classes):
        super(ResNet50, self).__init__()
        # Load the ResNet-50 architecture from torchvision with pretrained weights
        original_model = models.resnet50(pretrained=True)
        # Everything except the last linear layer
        self.features = nn.Sequential(*list(original_model.children())[:-1])
        # Number of input features of the original last layer
        num_feats = original_model.fc.in_features
        # Plug in our own classifier
        self.classifier = nn.Sequential(
            nn.Linear(num_feats, num_classes)
        )
        # Initialize the new classifier layer
        for m in self.classifier:
            nn.init.kaiming_normal_(m.weight)
        # Freeze all weights except the last classifier layer
        # for p in self.features.parameters():
        #     p.requires_grad = False

    def forward(self, x):
        f = self.features(x)
        f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y
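A quick sanity check of the module above (just a sketch: it assumes the pretrained weights can be downloaded, and the 224x224 input size and class count are arbitrary):

model = ResNet50(num_classes=10)
dummy = torch.randn(2, 3, 224, 224)   # batch of 2 RGB images
logits = model(dummy)
print(logits.shape)                   # torch.Size([2, 10])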
Is your question about using a sigmoid here:
    def forward(self, x):
        f = self.features(x)
        f = f.view(f.size(0), -1)
        y = self.classifier(f)
        y = F.sigmoid(y)  # Is this better?
        return y
Or at a higher level?
Currently, there is no difference.
Ideally, in the future you should use MultiLabelSoftMarginLoss during training, once it becomes numerically stable and faster; see PyTorch issue 1516.
Currently, MultiLabelSoftMarginLoss in PyTorch is implemented the naive way, as a separate Sigmoid pass followed by cross-entropy, while a fused implementation would be both faster and more accurate.
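To see the difference in practice, here is a quick sketch (the shapes are arbitrary; binary_cross_entropy_with_logits is the fused form that newer PyTorch releases provide):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 5) * 20                # large logits so that some sigmoids saturate to exactly 0.0 / 1.0 in float32
targets = torch.randint(0, 2, (4, 5)).float()

# Naive two-pass version: squash to probabilities first, then take the cross-entropy
naive = F.binary_cross_entropy(torch.sigmoid(logits), targets)

# The multi-label loss mentioned above, for comparison (expects multi-hot targets of the same shape)
msml = nn.MultiLabelSoftMarginLoss()(logits, targets)

# Fused version: works directly on the logits using the log-sum-exp trick below
fused = F.binary_cross_entropy_with_logits(logits, targets)

print(naive.item(), msml.item(), fused.item())  # the naive value degrades once the sigmoid saturates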
The proper way is to use the log-sum-exp trick to simplify the Sigmoid Cross-Entropy (SCE) expression from this (after naively substituting the sigmoid into the cross-entropy function):
SCE(x, y') = −1/n ∑_i ( t_i * (x_i − ln(1 + e^x_i)) + (1 − t_i) * (−ln(1 + e^x_i)) )

with t_i (read target_i) being the elements of y',

to this:

SCE(x, y') = −1/n ∑_i ( t_i * x_i − max(x_i, 0) − ln(1 + e^−|x_i|) )
This is more numerically stable and much faster to compute.
A full explanation of each simplification step is in my own PyTorch-like framework here.
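As a concrete sketch of that simplified expression (the helper name is mine; clamp and log1p cover the max(x_i, 0) and ln(1 + e^−|x_i|) terms):

import torch
import torch.nn.functional as F

def stable_sigmoid_cross_entropy(logits, targets):
    # SCE(x, y') = -1/n * sum_i( t_i * x_i - max(x_i, 0) - ln(1 + e^-|x_i|) )
    per_element = targets * logits - logits.clamp(min=0) - torch.log1p(torch.exp(-logits.abs()))
    return -per_element.mean()

x = torch.randn(4, 5) * 10
t = torch.randint(0, 2, (4, 5)).float()
print(stable_sigmoid_cross_entropy(x, t))
print(F.binary_cross_entropy_with_logits(x, t))  # should agree with the hand-rolled version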
Note: ln(1 + x) is also numerically unstable when x << 1 (much smaller than 1): 1 + x gets rounded to 1 and ln(1) returns 0 (catastrophic cancellation), even though ln(1 + x) ≈ x for small x. This means the network would wrongly stop training because there is no gradient. NumPy has the log1p function to avoid that, and PyTorch exposes torch.log1p as well.
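A quick illustration of the cancellation (float64 here; the threshold where 1 + x rounds to 1 depends on the precision):

import numpy as np

x = 1e-17
print(np.log(1 + x))   # 0.0: 1 + x rounds to exactly 1.0, so the small value is lost
print(np.log1p(x))     # 1e-17: matches the approximation ln(1 + x) ≈ x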