Aside from the solution kindly provided by @ptrblck, you can also do something like this:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class MyCustomResnet18(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        resnet18 = models.resnet18(pretrained=pretrained)
        # here we grab all the modules (layers) up to, but not including, the final fc layer.
        # note that we use children() and not named_children(): named_children() yields
        # (name, module) tuples, which nn.ModuleList cannot consume directly
        self.features = nn.ModuleList(resnet18.children())[:-1]
        # Now we have our layers up to the fc layer, but we are not finished yet.
        # We also need to feed these to nn.Sequential(), because nn.ModuleList
        # doesn't implement forward(), so you can't do something like self.features(images).
        # Since nn.Sequential doesn't accept a list, we unpack the items like this:
        self.features = nn.Sequential(*self.features)
        # now let's add our new layers
        in_features = resnet18.fc.in_features
        # from here on, you can add any kind of layer in any quantity!
        # Here I'm creating two new layers
        self.fc0 = nn.Linear(in_features, 256)
        self.fc0_bn = nn.BatchNorm1d(256, eps=1e-2)
        self.fc1 = nn.Linear(256, 256)
        self.fc1_bn = nn.BatchNorm1d(256, eps=1e-2)
        # initialize all fc layers with Xavier initialization
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_normal_(m.weight, gain=1)

    def forward(self, input_imgs):
        # in the forward pass you have full control; we can use the feature
        # extractor from our pretrained model like this
        output = self.features(input_imgs)
        # since we are using fc layers from now on, we need to flatten the output.
        # the avgpool layer leaves the shape (batch, features, 1, 1), so we reshape
        # to (batch, features). input_imgs.size(0) gives the batch size, and
        # we use -1 to infer the rest
        output = output.view(input_imgs.size(0), -1)
        # and then our new layers
        output = self.fc0_bn(F.relu(self.fc0(output)))
        output = self.fc1_bn(F.relu(self.fc1(output)))
        return output
```
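You can then sanity-check the shapes with a dummy batch (a minimal sketch; the batch size and the 224x224 input size are just placeholders):

```python
model = MyCustomResnet18(pretrained=True)
model.eval()
# a dummy batch of 4 RGB images at ResNet's usual 224x224 input size
dummy = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # torch.Size([4, 256]), since fc1 outputs 256 features
```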
You can also get fancy and add new methods to your network (e.g. for freezing and unfreezing different parts of it), which can come in handy for finetuning; see the sketch below.
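For instance, a freeze/unfreeze pair could look like this (a minimal sketch; the method names freeze_features/unfreeze_features are my own suggestion, not part of any PyTorch API):

```python
class MyCustomResnet18(nn.Module):
    # ... __init__ and forward as above ...

    def freeze_features(self):
        # stop gradients for the pretrained backbone, so only the new fc layers train
        for param in self.features.parameters():
            param.requires_grad = False

    def unfreeze_features(self):
        # re-enable gradients for the backbone, e.g. for a later finetuning stage
        for param in self.features.parameters():
            param.requires_grad = True
```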