Please point me to a resource that shows how to use XLS-R (wav2vec 2.0) in a PyTorch training setup. I've looked everywhere but found no guide or sample code covering:
- A PyTorch dataset and data loader to load waveforms in batches (can be integrated with the model parts below)
- Model part 1 - the XLS-R pretrained feature extractor
- Model part 2 - a downstream task, something simple like language classification (just depends on the data/labels)
- Training, validation, metrics, …
```python
import os

import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_XLSR53
model = bundle.get_model().to(args['device'])

# Load one sample and resample it to the rate the bundle expects
sample_path = os.path.join(args['trainsets'][0],
                           os.listdir(args['trainsets'][0])[0])
sample_waveform, sample_samplerate = torchaudio.load(sample_path)
sample_waveform = sample_waveform.to(args['device'])
waveform = torchaudio.functional.resample(
    sample_waveform, sample_samplerate, bundle.sample_rate)

# Extract acoustic features (one tensor per transformer layer)
with torch.inference_mode():
    features, _ = model.extract_features(waveform)
print(len(features))
for x, element in enumerate(features):
    print(x, element)
```
The above is what I have; it applies the feature extractor to a single sample waveform and produces 24 tensors (one per transformer layer, I believe). I need to apply it to the whole dataset in batches, add a few more layers on top of the model, and fine-tune the whole thing. Once this works, I also need to use the new XLS-R (the larger one, not in torchaudio yet). Any help is appreciated.
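For the "add a few layers and fine-tune" part, a sketch of one common pattern: wrap the bundle's encoder, mean-pool the last layer's features over time, and feed that to a linear classification head, then train everything end to end. The pooling choice, the head, the learning rate, `encoder_dim=1024` (the width of the XLSR53 transformer), and names like `train_loader` / `val_loader` are illustrative assumptions, not a canonical recipe:

```python
import torch


class LangClassifier(torch.nn.Module):
    """wav2vec2 encoder + mean-pooled linear head (illustrative design)."""

    def __init__(self, encoder, encoder_dim, num_classes):
        super().__init__()
        self.encoder = encoder  # e.g. bundle.get_model()
        self.head = torch.nn.Linear(encoder_dim, num_classes)

    def forward(self, waveforms, lengths=None):
        # extract_features returns one tensor per transformer layer;
        # take the last layer and mean-pool over the time axis.
        features, _ = self.encoder.extract_features(waveforms, lengths)
        return self.head(features[-1].mean(dim=1))


def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    for waveforms, lengths, labels in loader:
        logits = model(waveforms.to(device), lengths.to(device))
        loss = criterion(logits, labels.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


@torch.inference_mode()
def evaluate(model, loader, device):
    model.eval()
    correct = total = 0
    for waveforms, lengths, labels in loader:
        logits = model(waveforms.to(device), lengths.to(device))
        correct += (logits.argmax(dim=1).cpu() == labels).sum().item()
        total += labels.numel()
    return correct / total


# Usage sketch (downloads the pretrained weights, so left commented out):
# import torchaudio
# bundle = torchaudio.pipelines.WAV2VEC2_XLSR53
# model = LangClassifier(bundle.get_model(), encoder_dim=1024,
#                        num_classes=len(args['trainsets'])).to(args['device'])
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# criterion = torch.nn.CrossEntropyLoss()
# for epoch in range(5):
#     train_one_epoch(model, train_loader, optimizer, criterion, args['device'])
#     print(epoch, evaluate(model, val_loader, args['device']))
```

Because the optimizer is given all of `model.parameters()`, this fine-tunes the encoder as well as the head; pass only `model.head.parameters()` instead if you want to freeze the pretrained extractor as a first experiment.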
Thanks.