I’m trying to use intermediate feature representations from different points along a CNN pipeline. Is it OK to create multiple feature extractors from the same pre-trained CNN, and will the parameter updates be performed correctly? See the example below:
class Net(nn.Module):
    """Multi-depth feature extractor built on a shared pre-trained VGG-19.

    Four ``nn.Sequential`` extractors are built over the SAME layer module
    objects (not copies), so each shared conv layer's parameters are
    registered exactly once on this module and receive a single accumulated
    gradient per backward pass, even though several extractors use them.
    """

    def __init__(self, dim_l1=256, dim_l2=512, dim_l3=512, dim_fc=512):
        super(Net, self).__init__()
        vgg = models.vgg19(pretrained=True)
        # Materialize the layer list once and slice it; every Sequential
        # below shares the same underlying module objects (weight sharing).
        layers = list(vgg.features)
        self.full_feature_extractor = nn.Sequential(*layers)
        self.local_feature_extractor1 = nn.Sequential(*layers[:15])
        self.local_feature_extractor2 = nn.Sequential(*layers[:-13])
        self.local_feature_extractor3 = nn.Sequential(*layers[:-6])
        self.fc1 = nn.Linear(dim_fc, dim_fc)
        # fc2 input size = concatenation of the flattened intermediate maps.
        self.fc2 = nn.Linear(dim_l1 * 64 + dim_l2 * 16 + dim_l3 * 4, dim_fc)
        self.fc3 = nn.Linear(dim_fc, 10)

    @staticmethod
    def some_function(x1, **kwargs):
        # Placeholder from the original post; `some_output` is undefined
        # here, so calling this as-is raises NameError.
        ...
        return some_output

    @staticmethod
    def flatten(x):
        # Flatten everything after the batch dimension:
        # (B, d1, d2, ...) -> (B, d1*d2*...).
        return x.view(-1, int(np.prod([i for i in x.size()[1:]])))

    def forward(self, x):
        local_features1 = self.local_feature_extractor1(x)
        local_features2 = self.local_feature_extractor2(x)
        local_features3 = self.local_feature_extractor3(x)
        # BUG FIX: the original called self.global_feature_extractor, an
        # attribute never defined in __init__; the attribute created there
        # is self.full_feature_extractor.
        end_features = self.fc1(self.full_feature_extractor(x).squeeze())
        # NOTE(review): kwargs1/kwargs2/kwargs3 are undefined placeholders
        # in the original post, and kwargs3 is reused for both outputs 3
        # and 4 — possibly a typo; confirm intent.
        some_output1 = self.some_function(local_features1, **kwargs1)
        some_output2 = self.some_function(local_features2, **kwargs2)
        some_output3 = self.some_function(local_features3, **kwargs3)
        some_output4 = self.some_function(end_features, **kwargs3)
        g = torch.cat(
            [
                self.flatten(some_output1),
                self.flatten(some_output2),
                self.flatten(some_output3),
                self.flatten(some_output4),
            ],
            dim=1,
        )
        out = F.relu(self.fc2(g))
        out = F.relu(self.fc3(out))
        return F.softmax(out, dim=1)
When I check the parameters, I see the following output:
# Inspect how the shared VGG parameters are registered on the module.
net = Net()
# NOTE(review): the loop body appears truncated in the post. Also,
# named_parameters() yields (name, param) pairs, so with enumerate the
# variable `name` here binds the whole (name, param) tuple, not just the
# parameter's name string — confirm that is the intent.
for i, name in enumerate(list(net.named_parameters())):
So it looks like the parameters from VGG are all associated with the largest feature-extractor component (of which all the other feature extractors are subsets in this case). So each of them would be updated exactly once per backward pass?
Is this a reasonable approach? Please advise if an alternative would be better.