Since others may also prefer not to use hooks, I will leave this code snippet here; I found it quite useful for extracting feature vectors from torchvision
models (adapted from here):
import torch
import torchvision
def feature_forward(self, x):
    """Run the forward pass and also return the pre-classifier features.

    Written as a drop-in body for ``torchvision.models.ResNet.forward``. It
    only relies on the attributes it calls on ``self``: ``conv1``, ``bn1``,
    ``relu``, ``maxpool``, ``layer1`` .. ``layer4``, ``avgpool`` and ``fc``.

    NOTE: keep this body free of module-level names other than ``torch``.
    When a function's ``__code__`` is transplanted into another function,
    global names are resolved in that other function's module, not this one.

    Args:
        self: Object exposing the callables listed above.
        x: Input tensor passed to ``self.conv1``.

    Returns:
        A tuple ``(x, feature_vector)`` where ``x`` is the output of
        ``self.fc`` and ``feature_vector`` is the ``avgpool`` output
        flattened from dimension 1 onwards (the input given to ``self.fc``).
    """
    # Stem.
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    # Four stages, applied in order.
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    # Pool, then flatten everything except the leading (batch) dimension.
    x = self.avgpool(x)
    feature_vector = torch.flatten(x, 1)

    # Classifier head; the features are returned alongside its output.
    x = self.fc(feature_vector)
    return x, feature_vector
# Swap the bytecode of ResNet.forward in place. This patches the class itself,
# so every ResNet instance in this process now returns
# (fc output, feature_vector) instead of the fc output alone.
torchvision.models.ResNet.forward.__code__ = feature_forward.__code__

model = torchvision.models.resnet34()
model.eval()  # switch the model to evaluation mode before inference

# Only the output shapes are inspected, so skip building the autograd graph.
with torch.no_grad():
    classes, features = model(torch.randn(1, 3, 224, 224))

print(classes.shape)
print(features.shape)