I don’t know how the `extract_features` method is implemented, but your approach of returning the desired activation/feature sounds valid.
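As a minimal sketch of that approach (the model definition here is entirely made up, since I don't know what `extract_features` looks like), the forward pass can simply return the intermediate activation alongside the final output:

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(16 * 32 * 32, 10)

    def forward(self, x):
        feats = self.features(x)                 # desired intermediate activation
        out = self.classifier(feats.flatten(1))  # final output
        return out, feats                        # return both

model = MyModel()
out, feats = model(torch.randn(1, 3, 32, 32))
print(feats.shape)  # torch.Size([1, 16, 32, 32])
```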
Alternatively, you could use forward hooks, as described here, assuming you are interested in the output of a specific `nn.Module`.
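Here is a small sketch of the forward hook approach, using torchvision's `resnet18` as a stand-in model (the choice of `layer4` is just an example; register the hook on whichever submodule you need):

```python
import torch
import torchvision.models as models

activation = {}

def get_activation(name):
    # Forward hook: called after the module's forward pass,
    # stores a detached copy of its output
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

model = models.resnet18()
# Register the hook on the module whose output you want to capture
model.layer4.register_forward_hook(get_activation('layer4'))

x = torch.randn(1, 3, 224, 224)
out = model(x)
print(activation['layer4'].shape)  # torch.Size([1, 512, 7, 7])
```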
If you don’t want to train the model, you could also wrap the forward pass in a `with torch.no_grad()` block to avoid storing the intermediate activations, which would otherwise be kept for the gradient computation.
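For example (again with `resnet18` as a placeholder model):

```python
import torch
import torchvision.models as models

model = models.resnet18().eval()
x = torch.randn(1, 3, 224, 224)

# Disables gradient tracking, so activations needed only for
# the backward pass are not stored, reducing memory usage
with torch.no_grad():
    out = model(x)

print(out.requires_grad)  # False
```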