Torchaudio causes out-of-memory errors

The code below shows the places where torchaudio fails while other libraries work:

import random

import librosa
import scipy.signal
import torch
import torchaudio

def add_noise(waveform, noise_waveform, SNR):
    """Mix ``noise_waveform`` into ``waveform`` at the requested SNR.

    1-D inputs are promoted to shape ``(1, time)``. The noise is then made
    the same length as the waveform along the last dimension: a longer noise
    clip is cropped at a random offset, a shorter one is right-padded with
    zeros.

    Args:
        waveform: Audio tensor whose last dimension is time.
        noise_waveform: Noise tensor whose last dimension is time.
        SNR: Forwarded as ``snr`` to ``torchaudio.functional.add_noise``.

    Returns:
        The tensor returned by ``torchaudio.functional.add_noise``.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    if noise_waveform.dim() == 1:
        noise_waveform = noise_waveform.unsqueeze(0)

    # Get the length of waveform and noise_waveform along the time axis.
    waveform_len = waveform.shape[-1]
    noise_waveform_len = noise_waveform.shape[-1]

    if noise_waveform_len > waveform_len:
        # Crop a random window of the noise that is as long as the waveform.
        offset = random.randint(0, noise_waveform_len - waveform_len)
        noise_waveform = noise_waveform[..., offset:offset + waveform_len]
    elif noise_waveform_len < waveform_len:
        # Right-pad the last dimension with zeros. Unlike concatenating a
        # fresh torch.zeros(...), pad keeps the noise tensor's dtype and
        # device and works for any number of leading dimensions.
        noise_waveform = torch.nn.functional.pad(
            noise_waveform, (0, waveform_len - noise_waveform_len)
        )

    return torchaudio.functional.add_noise(waveform, noise_waveform, snr=SNR)

N = 100  # Number of iterations

for i in range(N):
    # Generate a random audio clip.
    # torch.randint excludes its upper bound: 1-9 seconds, 8000-47999 Hz.
    length = torch.randint(1, 10, (1,))
    sampling_rate = torch.randint(8000, 48000, (1,))
    orig_sr = sampling_rate.item()
    waveform = torch.randn(int(length.item() * orig_sr))

    # Convert to another sampling rate (again 8000-47999 Hz).
    new_sampling_rate = torch.randint(8000, 48000, (1,))
    new_sr = new_sampling_rate.item()
    # NOTE(review): arbitrary random rate pairs rarely share a large common
    # divisor; that is presumably what makes the torchaudio resampling kernel
    # huge here - confirm against the torchaudio resample documentation.
    # resampled_waveform = torchaudio.transforms.Resample(sampling_rate.item(), new_sampling_rate.item())(waveform) # this leads to memory burst
    # resampled_waveform = torchaudio.functional.resample(waveform, sampling_rate.item(), new_sampling_rate.item()) # this leads to memory burst
    target_len = int(waveform.size(0) * new_sr / orig_sr)
    resampled_waveform = torch.tensor(
        scipy.signal.resample(waveform.numpy(), target_len),
        dtype=torch.float32,
    )

    # Do some augmentation on the resampled waveform.
    # 1. pitch shift (-4 to 3 semitone steps; upper bound is exclusive)
    pitch_shift_factor = torch.randint(-4, 4, (1,)).item()
    # resampled_waveform = torchaudio.functional.pitch_shift(resampled_waveform, new_sampling_rate.item(), pitch_shift_factor) # this leads to memory burst
    # pitch shift using librosa
    resampled_waveform = torch.tensor(
        librosa.effects.pitch_shift(
            resampled_waveform.numpy(), sr=new_sr, n_steps=pitch_shift_factor
        ),
        dtype=torch.float32,
    )  # it's working

    # 2. time stretch
    # Keep the rate in [0.5, 1.5): a raw torch.rand value can be 0 (invalid
    # rate) or arbitrarily close to 0, which stretches the clip by ~1/rate
    # and exhausts memory regardless of the library used.
    time_stretch_factor = 0.5 + torch.rand(1).item()
    # resampled_waveform, _ = torchaudio.functional.speed(resampled_waveform, new_sampling_rate.item(), time_stretch_factor) # this leads to memory burst
    # time stretch using librosa
    resampled_waveform = torch.tensor(
        librosa.effects.time_stretch(
            resampled_waveform.numpy(), rate=time_stretch_factor
        ),
        dtype=torch.float32,
    )  # it's working

    # 3. volume gain (-10 to 9 dB)
    # volume_gain = torch.rand(1).item()
    # resampled_waveform = resampled_waveform * volume_gain # it's working
    gain_db = torch.randint(-10, 10, (1,)).item()
    resampled_waveform = torchaudio.functional.gain(resampled_waveform, gain_db=gain_db)  # it's working as well

    # 4. adding noise (SNR is a float32 tensor of shape (1,), 0-29 dB)
    noise_waveform = torch.randn_like(resampled_waveform)
    SNR = torch.randint(0, 30, (1,)).to(torch.float32)

    print(f"waveform shape: {waveform.shape}, resampled_waveform shape: {resampled_waveform.shape}")
    print(f"noise waveform shape: {noise_waveform.shape}")
    print(f"SNR shape: {SNR.shape}, SNR: {SNR.item()}")
    resampled_waveform = add_noise(resampled_waveform, noise_waveform, SNR)

    print(f"waveform shape: {waveform.shape}, resampled_waveform shape: {resampled_waveform.shape}")
    # Print information
    print(f"Iteration {i+1}: Original sampling rate = {orig_sr} Hz, New sampling rate = {new_sr} Hz")

Please run the code above to reproduce the memory error.
Thanks!

I found three places where torchaudio gives a memory error (see the code above):

  1. Resampling
  2. pitch shift
  3. speed (time stretch)

Hi @ptrblck, I will be waiting for your response.
Thanks again!