I found a related topic regarding my question here, but I could not find my answer there, so I am asking it here.
Let's say I'm using the pretrained VGG and I want to extract the features from some specific layers.
Here is what I think I should do:
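import torch
import torch.nn as nn
import torchvision.models as models

# Load the VGG16 model:
vgg16 = models.vgg16(pretrained=True)
# Cut out the part that I want (the convolutional "features" block):
new_base = list(vgg16.children())[:-1][0]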
# If I print new_base, I get:
# Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (1): ReLU(inplace)
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (3): ReLU(inplace)
# (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# ...
# ...
# (30): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
# )
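# Then here is my feature extractor class:
class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super().__init__()  # required before assigning submodules
        self.submodule = submodule
        self.extracted_layers = extracted_layers  # child names, e.g. ["2", "5"]

    def forward(self, x):
        outputs = []
        for name, module in self.submodule.named_children():
            x = module(x)
            if name in self.extracted_layers:
                # clone, because the following ReLU(inplace) would overwrite x
                outputs.append(x.clone())
        return outputs + [x]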
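One detail worth checking in the loop above: in VGG16 every Conv2d is followed by `ReLU(inplace)`, so a tensor stored right after a convolution is overwritten in place by the next layer unless it is copied first. A minimal demonstration, using a hand-made tensor as a stand-in for a convolution output:

```python
import torch
import torch.nn as nn

relu = nn.ReLU(inplace=True)

# Stand-in for a convolution output that contains a negative value.
x = torch.tensor([-1.0, 2.0])
stored = x            # keeps a reference to the same tensor
copied = x.clone()    # independent copy made before the ReLU

x = relu(x)           # modifies the tensor in place

print(stored)  # the "saved" output has been clamped as well: tensor([0., 2.])
print(copied)  # the clone still holds the pre-ReLU values: tensor([-1., 2.])
```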
Let's say I have my input and I call it Input.
I am interested in the features computed by the layers
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
and
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
layers.
Can you please tell me what I should do, and how?
When I call FeatureExtractor, should I do something like FeatureExtractor(new_base, ???)?
I am not sure what to put in place of ???, and I also don't know how to pass my input to the forward function.
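For the class as written, the second argument would be the list of child names, and because `nn.Sequential` names its children with strings, that list would be `["2", "5"]`, not `[2, 5]`. The input is passed by calling the module, e.g. `FeatureExtractor(new_base, ["2", "5"])(Input)`, which invokes `forward` (assuming the constructor calls `super().__init__()` and stores `extracted_layers`). A different approach, which needs no wrapper class, is to register forward hooks on layers 2 and 5. The sketch below uses a hypothetical stand-in `nn.Sequential` that mirrors indices 0 to 5 of `vgg16.features`, so it runs without downloading weights; with the real network, `new_base[2]` and `new_base[5]` would be hooked the same way.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the first six layers of vgg16.features (indices 0-5).
new_base = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
)

features = {}

def save_output(name):
    # Forward hooks receive (module, input, output); clone because the
    # following ReLU(inplace=True) would otherwise overwrite the tensor.
    def hook(module, inputs, output):
        features[name] = output.detach().clone()
    return hook

handles = [new_base[i].register_forward_hook(save_output(str(i))) for i in (2, 5)]

Input = torch.randn(1, 3, 32, 32)  # hypothetical input batch
with torch.no_grad():
    out = new_base(Input)

for h in handles:
    h.remove()  # detach the hooks once the features are collected

print({k: tuple(v.shape) for k, v in features.items()})
```

Hooks leave the original model untouched, which can be convenient when the layers of interest are nested deeper than a single `nn.Sequential`.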
My assumption is that, because I loaded the pretrained network, I don't have to train it anymore: it already has the pretrained weights and I can just use them, right?
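For pure feature extraction no training is needed, but it is common to put the model in evaluation mode, freeze its parameters, and run it under `torch.no_grad()` to save memory. The torchvision pretrained models also expect inputs normalized with the ImageNet mean and standard deviation used below. This is a sketch under assumptions: the small `nn.Sequential` is a hypothetical stand-in for `new_base` so the example runs without downloading weights, and `image` is a hypothetical RGB tensor already scaled to [0, 1].

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for new_base; with the real model this would be
# list(vgg16.children())[:-1][0] from a pretrained vgg16.
new_base = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)

new_base.eval()                      # inference behaviour for dropout/batch norm, if any
for p in new_base.parameters():
    p.requires_grad_(False)          # freeze the (pretrained) weights

# ImageNet statistics used by the torchvision pretrained models.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

image = torch.rand(1, 3, 224, 224)   # hypothetical RGB image scaled to [0, 1]
Input = (image - mean) / std

with torch.no_grad():
    y = new_base(Input)

print(y.shape, y.requires_grad)
```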
I also have a second, related question: is there any way that I can just do
y = new_base[n](x)
where n is the index of the layer I am interested in and x is the input, and get the output y?
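Indexing an `nn.Sequential` returns a single layer, so `new_base[n](x)` applies only layer `n` and expects `x` to already be the output of layer `n - 1` (for `n = 5`, a 64-channel tensor). To get the output of layer `n` from the raw input, one option is to slice the container: `new_base[:n + 1](x)` runs layers 0 through `n`. The sketch below uses a hypothetical six-layer stand-in for `new_base` and a hypothetical helper `output_of_layer`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in mirroring indices 0-5 of vgg16.features.
new_base = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
)

def output_of_layer(model, x, n):
    # Slicing an nn.Sequential gives a new nn.Sequential with layers 0..n,
    # sharing the same (pretrained) modules, so no weights are copied.
    return model[: n + 1](x)

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y2 = output_of_layer(new_base, x, 2)
    y5 = output_of_layer(new_base, x, 5)

print(y2.shape, y5.shape)
```

Each call reruns the earlier layers, so when several layers are needed at once, a single pass that collects intermediate outputs is cheaper.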
Thanks