I am using a pre-trained VGG model for image classification, but I only want to use VGG to extract features; I don't need the classification output.
Can I do this?
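```python
import torch
import torch.nn as nn
import torchvision

# Load VGG16-BN with ImageNet weights (torchvision >= 0.13; older versions use pretrained=True)
vgg16 = torchvision.models.vgg16_bn(weights=torchvision.models.VGG16_BN_Weights.DEFAULT)

# Drop the last classifier layer, Linear(4096, 1000), so the model outputs 4096-d features
feature_extractor = nn.Sequential(*list(vgg16.classifier.children())[:-1])
vgg16.classifier = feature_extractor
vgg16.eval()  # disable dropout and use running batch-norm statistics

image = torch.randn(1, 3, 224, 224)  # VGG was trained on 224x224 inputs
with torch.no_grad():
    output = vgg16(image)
print(output.shape)  # torch.Size([1, 4096])
```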