Coming from TensorFlow (and liking it here; this is a great community!), I am probably just approaching this the wrong way.
I am trying to write a really simple feature extraction script.
Given an input image (or a stack of images), I just want to compute some features using a pretrained VGG19 and dump them into a binary file. No backprop, no bells or whistles, really.
Thing is, exactly which features to extract is a parameter to my script.
In TensorFlow, this would essentially be something along the lines of:

```python
logits, end_points = vgg19(inputs)
return [end_points[k] for k in targets]
```

where targets is a list of layer names such as conv1_1 or fc6.
How do I do something similar in pytorch?
From what I have gathered so far, the following would be possible:
1. create a subclass of nn.Module
2. inside that module, load vgg19
3. essentially rebuild the entire vgg19 model from the submodules of the stock vgg19, giving each submodule a name so it can be addressed by that name later
4. define forward() to run the forward pass up to the last target, collecting the intermediate results in a list and returning them.
This seems awfully complicated for what I’m trying to achieve, so I’m probably approaching this the wrong way… help, please?
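For comparison, a pattern that avoids rebuilding the network at all is PyTorch's forward hooks: `register_forward_hook` attaches a callback to any submodule and records its output during an ordinary forward pass. Below is a minimal sketch, assuming targets are given as the module names reported by `named_modules()` (for torchvision's vgg19 that would be e.g. "features.2" for conv1_2 and "classifier.0" for fc6); `extract_features` is a hypothetical helper name. One trade-off: unlike a hand-written forward() that stops at the last target, this always runs the whole model.

```python
import torch


def extract_features(model, inputs, targets):
    """Run `model` once and return {name: output} for each module name in `targets`.

    Names are those reported by model.named_modules(), e.g. "features.2" in
    torchvision's VGG19 (print the names to check for other models).
    """
    modules = dict(model.named_modules())
    missing = [t for t in targets if t not in modules]
    if missing:
        raise ValueError(f"Unknown module names: {missing}")

    outputs = {}
    handles = []
    for name in targets:
        # Default-argument binding captures the current `name` for each hook.
        def hook(module, inp, out, name=name):
            # clone(): torchvision's VGG uses ReLU(inplace=True), which would
            # otherwise overwrite a captured conv output in place.
            outputs[name] = out.detach().clone()

        handles.append(modules[name].register_forward_hook(hook))

    try:
        with torch.no_grad():  # pure feature extraction, no autograd graph
            model(inputs)
    finally:
        for handle in handles:  # always detach the hooks again
            handle.remove()
    return outputs
```

A call such as extract_features(vgg.eval(), images, ["features.2", "classifier.0"]) would then return a dict with both tensors; eval() matters here because the VGG classifier contains Dropout layers.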
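Here is what I have come up with so far:

```python
import torch
import torch.nn as nn


class VGG19FeatureExtractor(nn.Module):
    def __init__(self, net, targets):
        super(VGG19FeatureExtractor, self).__init__()
        self.features = net.features
        # torchvision's VGG applies an adaptive avgpool between features and
        # classifier (a no-op for 224x224 inputs); fall back to Identity if absent.
        self.avgpool = getattr(net, 'avgpool', nn.Identity())
        self.classifier = net.classifier
        if isinstance(targets, int):
            self.targets = {targets}  # TODO: make this nice strings
            self.is_list = False
        else:
            self.targets = set(targets)  # TODO: make this nice strings
            self.is_list = True
        max_index = len(self.features) + len(self.classifier)
        for t in self.targets:  # iterate the set, so a single int target works too
            if not 0 <= t < max_index:
                raise ValueError(f'Invalid target index {t}. Should be within [0, {max_index})')

    def forward(self, x):
        # Results come back in layer order, not in the order the targets were given.
        results = list()
        for i, module in enumerate(self.features):
            x = module(x)
            if i in self.targets:
                print(f'{i}\t{module}')
                # clone(): the following ReLU(inplace=True) would overwrite x otherwise
                results.append(x.clone())
                if len(results) == len(self.targets):
                    return results if self.is_list else results[0]
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        i_ = len(self.features)  # classifier indices are numbered starting from this
        for i, module in enumerate(self.classifier):
            x = module(x)
            if i + i_ in self.targets:
                print(f'{i + i_}\t{module}')
                results.append(x.clone())
                if len(results) == len(self.targets):
                    return results if self.is_list else results[0]
        raise RuntimeError('should have returned before here. what happened?!')


# Usage sketch (vgg19_model: a torchvision VGG19, where index 2 is conv1_2 and
# index 37 is fc6). eval() disables dropout; no_grad() skips autograd:
# extractor = VGG19FeatureExtractor(vgg19_model, [2, 37]).eval()
# with torch.no_grad():
#     conv1_2, fc6 = extractor(images)
```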
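This class takes a vgg19 network as input, together with either a single target or a collection of targets, given as integer layer indices (the features layers first, then the classifier layers).
It computes the requested output(s) and returns a single tensor for a single target, or a list for a collection.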
While this seems to be doing what I want, I have several issues with it:
- It is incredibly verbose for what it does.
- It uses integers as targets instead of names such as conv1_2 or pool5.
- It relies on the features and classifier attributes of the VGG networks, so it won't work with, e.g., ResNet.
- The double for-loop makes my heart hurt. A simple Flatten module might mitigate this somewhat, though.
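For the integers-versus-names point, one option is a lookup table built by walking the VGG layers and numbering them with the Caffe-style names used in the question (conv1_1, relu1_1, pool1, ..., fc6, relu6, drop6). This is a hypothetical helper, assuming torchvision's plain (non-BN) VGG layout with features and classifier attributes; the naming scheme is a convention layered on top, since torchvision itself does not attach these names. Its values are the same integer indices the class above expects.

```python
import torch.nn as nn


def vgg_layer_names(net):
    """Hypothetical helper: map Caffe-style layer names to integer indices,
    counting features layers first and then classifier layers."""
    names = {}
    block, conv, relu = 1, 0, 0
    for i, m in enumerate(net.features):
        if isinstance(m, nn.Conv2d):
            conv += 1
            names[f"conv{block}_{conv}"] = i
        elif isinstance(m, nn.ReLU):
            relu += 1
            names[f"relu{block}_{relu}"] = i
        elif isinstance(m, nn.MaxPool2d):
            names[f"pool{block}"] = i
            block, conv, relu = block + 1, 0, 0  # a max-pool closes a block
        # other layer types (e.g. BatchNorm in *_bn variants) are left unnamed

    offset = len(net.features)
    fc = 6  # VGG numbers its fully connected layers fc6, fc7, fc8
    for j, m in enumerate(net.classifier):
        if isinstance(m, nn.Linear):
            names[f"fc{fc}"] = offset + j
            fc += 1
        elif isinstance(m, nn.ReLU):
            names[f"relu{fc - 1}"] = offset + j
        elif isinstance(m, nn.Dropout):
            names[f"drop{fc - 1}"] = offset + j
    return names
```

For a torchvision vgg19 this maps conv1_2 to 2, pool5 to 36 and fc6 to 37, so something like VGG19FeatureExtractor(vgg, [names["conv1_2"], names["pool5"]]) would accept names indirectly without changing the class.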
EDIT:
And yes, I used f-strings for the string interpolation out of principle. Change them if you want to use this with Python 2.