Memory/Performance of Training Wav2vec2 Model

I don’t know how the extract_features method is implemented, but your approach of returning the desired activation/feature sounds valid.
Alternatively, you could try to use forward hooks as described here, assuming you are interested in the output of a specific nn.Module.
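Something along these lines might work as a minimal sketch. I'm assuming the Hugging Face Wav2Vec2Model here, and the layer index (5) is just a placeholder for whichever submodule you actually care about:

```python
import torch
from transformers import Wav2Vec2Config, Wav2Vec2Model  # assumption: HF implementation

# randomly initialized model just for illustration; any nn.Module works the same way
model = Wav2Vec2Model(Wav2Vec2Config())
model.eval()

captured = {}

def save_output(module, inputs, output):
    # the encoder layer returns a tuple; keep the hidden states and detach them
    hidden = output[0] if isinstance(output, tuple) else output
    captured["layer_5"] = hidden.detach()

# assumption: we want the output of the 6th transformer layer
handle = model.encoder.layers[5].register_forward_hook(save_output)

dummy_audio = torch.randn(1, 16000)  # 1 second of dummy audio at 16 kHz
with torch.no_grad():
    model(dummy_audio)

print(captured["layer_5"].shape)  # e.g. (1, 49, 768) for the base config
handle.remove()  # remove the hook once you are done
```

Detaching inside the hook keeps the stored tensor from holding on to the autograd graph, which matters if you run the forward pass with gradients enabled.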
If you don’t want to train the model, you could also wrap the forward pass into a with torch.no_grad() statement, which avoids storing the intermediate activations that would otherwise be kept for the gradient computation and thus reduces memory usage.
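A minimal sketch, assuming `model` and `dummy_audio` from the snippet above:

```python
import torch

# inside no_grad the intermediate activations are not saved for backward,
# so memory usage drops when you only need the features
with torch.no_grad():
    out = model(dummy_audio)

features = out.last_hidden_state  # has no grad_fn; cannot be backpropagated through
```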