How to use and fine-tune the XLS-R (wav2vec2) feature extractor in PyTorch?

Please point me to a resource that shows how to use XLS-R (wav2vec2) in a PyTorch training setup. I’ve looked all over but found no guide or sample code for:

  1. PyTorch dataset and data loader to load waveforms in batches (can be integrated with #2 below)
  2. Model part 1 - XLS-R pretrained feature extractor
  3. Model part 2 - downstream task, something simple like language classification (depends only on the data/labels)
  4. Training, validation, metrics, …
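
    import os

    import torch
    import torchaudio

    # `args` is my config dict (holds the device and the training set directories)
    bundle = torchaudio.pipelines.WAV2VEC2_XLSR53
    model = bundle.get_model().to(args['device'])

    # Load one sample file from the first training set directory
    sample_path = os.path.join(args['trainsets'][0], os.listdir(args['trainsets'][0])[0])
    sample_waveform, sample_samplerate = torchaudio.load(sample_path)  # (channels, time)
    sample_waveform = sample_waveform.to(args['device'])

    # Resample audio to the sampling rate the model expects
    waveform = torchaudio.functional.resample(sample_waveform,
                                              sample_samplerate, bundle.sample_rate)

    # Extract acoustic features: a list with one tensor per transformer layer
    with torch.inference_mode():
        features, _ = model.extract_features(waveform)

    for x, element in enumerate(features):
        print(x, element)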

The above is what I have. It applies the feature extractor to the sample waveform and produces 24 tensors (the extracted features, I believe). I need to apply it to the whole dataset in batches, add a few more layers to the model, and fine-tune the whole model. Once this works, I also need to use the newer, larger XLS-R models, which are not yet available in PyTorch. Any help is appreciated.


Hi @Wafaa_Wardah, the 24 tensors are the outputs of the 24 transformer layers in the XLS-R model. You can pick one of them, or take a weighted sum of all of them, as the final feature. Each tensor should have shape (batch, frame, feature_dim).
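A minimal sketch of the weighted-sum option, assuming the per-layer outputs come from `model.extract_features` as in the question. A softmax over learnable per-layer weights is one common parameterization, not the only one; `WeightedLayerSum` is a hypothetical name, not a torchaudio API. Because the weights are `nn.Parameter`s, they are trained together with the rest of the model.

```python
import torch
import torch.nn as nn


class WeightedLayerSum(nn.Module):
    """Learnable softmax-weighted sum over per-layer features.

    Expects a list of tensors, one per transformer layer, each of shape
    (batch, frame, feature_dim), e.g. the first return value of
    torchaudio's Wav2Vec2Model.extract_features.
    """

    def __init__(self, num_layers: int):
        super().__init__()
        # One scalar weight per layer; all zeros means a plain average at init.
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))

    def forward(self, layer_features: list[torch.Tensor]) -> torch.Tensor:
        stacked = torch.stack(layer_features, dim=0)  # (layers, batch, frame, dim)
        weights = torch.softmax(self.layer_weights, dim=0)  # non-negative, sums to 1
        # Broadcast the per-layer weights over batch, frame and feature dims.
        return (weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)
```

Usage: `combined = WeightedLayerSum(len(features))(features)` gives a single (batch, frame, feature_dim) tensor; picking one layer instead is simply `features[k]`.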

For 1), loading waveforms in batches with a PyTorch dataset and data loader, you can check the code of the LibriSpeech dataset in torchaudio for the dataset part, and the LibriSpeech ASR training recipe for DataLoader usage.

For 3), I’m not sure which downstream task you want to apply; you can check the S3PRL repository, which aims to benchmark self-supervised learning models on various downstream tasks.
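For a simple utterance-level task such as language classification, one option is to pool the chosen layer's frames into one vector per clip and add a linear head. This is a minimal sketch, assuming the encoder behaves like torchaudio's `Wav2Vec2Model`, i.e. `extract_features(waveforms, lengths)` returns a list of (batch, frame, dim) tensors plus frame lengths (or `None`). `UtteranceClassifier` is a hypothetical name, and mean pooling is just one simple choice.

```python
import torch
import torch.nn as nn


class UtteranceClassifier(nn.Module):
    """Hypothetical downstream model: encoder + masked mean pooling + linear head."""

    def __init__(self, encoder: nn.Module, feature_dim: int, num_classes: int,
                 layer: int = -1):
        super().__init__()
        self.encoder = encoder  # e.g. bundle.get_model(); stays trainable
        self.layer = layer  # which transformer layer's output to use
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, waveforms, lengths=None):
        features, frame_lengths = self.encoder.extract_features(waveforms, lengths)
        x = features[self.layer]  # (batch, frame, dim)
        if frame_lengths is None:
            pooled = x.mean(dim=1)
        else:
            # Average only over valid (non-padded) frames of each clip.
            frames = torch.arange(x.shape[1], device=x.device)
            mask = (frames[None, :] < frame_lengths[:, None]).unsqueeze(-1).to(x.dtype)
            pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return self.head(pooled)  # (batch, num_classes) logits
```

For XLS-R 53 the transformer width is 1024, so this would be built as `UtteranceClassifier(bundle.get_model(), feature_dim=1024, num_classes=num_languages)`, where `num_languages` is your own label count. Since the encoder's parameters are left trainable, an optimizer over `model.parameters()` fine-tunes the whole model.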

For 4), I recommend using PyTorch Lightning, which provides simplified APIs for training and validation, plus TensorBoard support.
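For comparison, the loop that Lightning would handle for you can be sketched in plain PyTorch as below, with accuracy as the validation metric. It assumes batches of `(waveforms, lengths, labels)` and a model called as `model(waveforms, lengths)` that returns class logits; both are assumptions of this sketch, not Lightning or torchaudio APIs.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, device):
    """One pass over the training batches; returns the mean loss."""
    model.train()  # enable training behaviour such as dropout
    loss_fn = nn.CrossEntropyLoss()
    total_loss, num_batches = 0.0, 0
    for waveforms, lengths, labels in loader:
        waveforms, lengths, labels = (
            waveforms.to(device), lengths.to(device), labels.to(device))
        optimizer.zero_grad()
        loss = loss_fn(model(waveforms, lengths), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


@torch.no_grad()
def evaluate(model, loader, device):
    """Returns classification accuracy over the validation batches."""
    model.eval()
    correct, total = 0, 0
    for waveforms, lengths, labels in loader:
        logits = model(waveforms.to(device), lengths.to(device))
        correct += (logits.argmax(dim=-1) == labels.to(device)).sum().item()
        total += labels.numel()
    return correct / max(total, 1)
```

In Lightning, the body of each loop iteration would instead go into `training_step` and `validation_step`, and the trainer would take care of devices, logging, and checkpointing.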