class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        self.vgg = models.vgg16()
        self.vgg.load_state_dict(torch.load(vgg_checkpoint))
        # Rebuild the classifier without its last fc layer (classifier[6]),
        # so forward() returns 4096-dim features instead of class logits
        self.vgg.classifier = nn.Sequential(
            *(self.vgg.classifier[i] for i in range(6)))

    def forward(self, images):
        return self.vgg(images)
A VGG net can be viewed as the combination of two sub-nets: a feature-extracting net and a classifying net, each of which is an nn.Sequential module. I just remove the last fc layer of the classifying net by constructing a new nn.Sequential module from the pretrained parameters.
For your requirement, I guess you can do it like this:
def __init__(self):
    super(EncoderCNN, self).__init__()
    self.vgg = models.vgg16()
    self.vgg.load_state_dict(torch.load(vgg_checkpoint))
    # Keep only features[0..29], i.e. up through relu5_3 (drops the final max pool)
    self.vgg.features = nn.Sequential(
        *(self.vgg.features[i] for i in range(30)))

def forward(self, images):
    return self.vgg.features(images)
Sorry that this method looks so ugly. The nn.Sequential object does not support slicing, so I have to construct a new Sequential with a generator expression.
I guess the following code does the same thing, right?

pretrained_model = torchvision.models.vgg16(pretrained=True)
modified_pretrained = nn.Sequential(*list(pretrained_model.features.children())[:-1])  # to relu5_3
This is because there is no module in your pre-trained model named features. "features" is one of the modules of VGG (the initial example of this thread). To see the module names, simply print your model.