Hmm... all is well for simple networks, like AlexNet.
What if I want to use something a bit fancier, say ResNet-18, and extract some intermediate representation?
I can fetch the model with
torchvision.models.resnet18(pretrained=True). I think the network graph can be reconstructed by forwarding a
Variable and checking each node's parent, but I'm not sure how to do this other than by hand.
Moreover, I think the
classes (the mapping from output class index to label) are missing.
Can I register a hook on a specific child, and have it return its current output? In the old Torch I would use
model:get(myModuleIndex).output, or a more nested combination, and get
myModuleIndex from visual inspection of the model structure (using the
How do we hack these models?
OK, I've managed to display the graph with
from visualize import make_dot (here's
Now how can I reach inside the graph?
So, let's call the output
h_x (which is $h_\theta(x)$).
So, I can peek at my embedding if I do the following dirty sequence of operations (which I guess are not recommended).
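x = Variable(torch.rand(1, 3, 224, 224))
h_x = resnet_18.forward(x)
# previous_functions holds (function, output index) pairs; take the first function
last_view = h_x.creator.previous_functions[0][0]
last_pool = last_view.previous_functions[0][0]
# saved_tensors is a tuple; its first entry is the input cached by the pooling op
embedding = last_pool.saved_tensors[0]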
embedding should correspond to the output of the last
Threshold, if I'm not mistaken. Still, this is not generally applicable to every functional block, since not all blocks cache their input.
So, I'm still after a better way to dissect these nets...