I am working with WAV audio files sampled at 44,100 Hz (44.1 kHz), which I need to load with torchaudio. My current implementation, using PyTorch and PyTorch Lightning, is shown below:
import os
import random
from typing import Any, Dict, Optional
import torchaudio
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from typing_extensions import Self
class AudioDataset(Dataset):
    """
    Loads waveforms from a flat directory of audio files.

    Each item is the tensor returned by ``torchaudio.load`` (channels-first by
    default), resampled to ``sample_rate``. When ``num_frames`` is set, every
    item has exactly ``num_frames`` frames so items can be batched by the
    default collate function.
    """

    def __init__(self, data_dir: str, sample_rate: int,
                 num_frames: Optional[int] = None) -> None:
        """
        Initializes the dataset.

        Parameters
        ----------
        data_dir : str
            Path to the folder with the audio files. Only regular,
            non-hidden files directly inside this folder are used.
        sample_rate : int
            Target sample rate that the audio will be resampled to.
        num_frames : int, optional
            Number of frames (at the target sample rate) per sample. When
            given, a random slice starting at a random offset is loaded, and
            clips shorter than this are zero-padded at the end. When ``None``
            (or 0), the whole file is loaded. Full files can differ in length,
            so batching them requires a custom ``collate_fn``.
        """
        super().__init__()
        self.data_dir = os.path.expanduser(data_dir)
        self.sample_rate = sample_rate
        self.num_frames = num_frames
        # Skip hidden files (e.g. ``.DS_Store``) and sub-directories, which
        # torchaudio cannot decode. Sort so that index -> file is deterministic
        # across runs (``os.listdir`` order is arbitrary).
        self.filenames = sorted(
            name for name in os.listdir(self.data_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.data_dir, name))
        )
        # One Resample transform per (source, target) rate pair. Building the
        # filter kernel is expensive, so it is reused instead of recomputed on
        # every item as ``torchaudio.functional.resample`` would.
        self._resamplers: Dict[tuple, Any] = {}

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, index: int) -> Tensor:
        """Load the ``index``-th file, either as a random slice or in full."""
        filepath = os.path.join(self.data_dir, self.filenames[index])
        if self.num_frames:
            return self._load_random_audio_slice(filepath)
        return self._load_full_audio(filepath)

    def _source_frames_needed(self, source_rate: int) -> int:
        """Return the number of source-rate frames that cover ``num_frames`` target frames, rounded up."""
        # Integer ceil division. Rounding up (instead of truncating) keeps the
        # resampled clip from coming out a frame short; any excess is trimmed.
        return -(-self.num_frames * source_rate // self.sample_rate)

    def _load_random_audio_slice(self, filepath: str) -> Tensor:
        """Load a random window of exactly ``num_frames`` target-rate frames."""
        metadata = torchaudio.info(filepath)
        frames_to_load = self._source_frames_needed(metadata.sample_rate)
        # The offset is in source frames, so bound it by the number of source
        # frames actually read. Clamp at 0 for files shorter than the window,
        # which would otherwise make ``randint`` raise ValueError.
        max_offset = max(0, metadata.num_frames - frames_to_load)
        frame_offset = random.randint(0, max_offset)
        waveform, sample_rate = torchaudio.load(filepath, num_frames=frames_to_load,
                                                frame_offset=frame_offset,
                                                normalize=True)
        waveform = self._resample(waveform, sample_rate, self.sample_rate)
        return self._fit_length(waveform, self.num_frames)

    def _load_full_audio(self, filepath: str) -> Tensor:
        """Load the whole file and resample it to the target rate."""
        waveform, sample_rate = torchaudio.load(filepath, normalize=True)
        return self._resample(waveform, sample_rate, self.sample_rate)

    def _resample(self, x: Tensor, sample_rate: int, target_sample_rate: int) -> Tensor:
        """Resample ``x`` from ``sample_rate`` to ``target_sample_rate``; no-op when they match."""
        if sample_rate == target_sample_rate:
            return x
        key = (sample_rate, target_sample_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
            self._resamplers[key] = resampler
        return resampler(x)

    @staticmethod
    def _fit_length(waveform: Tensor, num_frames: int) -> Tensor:
        """Trim or zero-pad the last (time) axis to exactly ``num_frames``."""
        length = waveform.shape[-1]
        if length >= num_frames:
            return waveform[..., :num_frames]
        padded = waveform.new_zeros((*waveform.shape[:-1], num_frames))
        padded[..., :length] = waveform
        return padded
class AudioDataModule(LightningDataModule):
    """
    Encapsulates data loading for an ``AudioDataset``.
    """

    def __init__(self, data_dir: str, sample_rate: int,
                 num_frames: Optional[int] = None, batch_size: int = 32,
                 num_workers: int = 4, pin_memory: bool = False, **kwargs: Any) -> None:
        """
        Initializes the module.

        Parameters
        ----------
        data_dir : str
            Path to the folder with the audio files.
        sample_rate : int
            Target sample rate that the audio will be resampled to.
        num_frames : int, default=None
            Number of frames to limit loading to per sample. A random offset from the
            start of the audio will be chosen when the number of frames is specified.
        batch_size : int, default=32
            Size of the mini-batch.
        num_workers : int, default=4
            Number of worker processes used for loading.
        pin_memory : bool, default=False
            Copy batches into pinned (page-locked) host memory, which speeds up
            host-to-GPU transfers.
        **kwargs : Any
            Ignored. Accepted only as a convenience, so the module can be
            initialized from a dictionary that contains extra keys.
        """
        super().__init__()
        self.data_dir = data_dir
        self.sample_rate = sample_rate
        self.num_frames = num_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the dataset. ``stage`` is accepted but unused."""
        self.dataset = AudioDataset(self.data_dir, self.sample_rate, self.num_frames)

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled training loader over the dataset built in ``setup``."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            # The training order should be randomized every epoch.
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            # Keep workers alive between epochs instead of re-spawning them.
            # This is only valid when workers exist (DataLoader raises otherwise).
            persistent_workers=self.num_workers > 0,
        )

    @classmethod
    def from_dict(cls, dict: Dict[str, Any], **kwargs: Any) -> Self:
        """Build the module from a config dict; entries in ``kwargs`` override ``dict``."""
        # ``dict`` shadows the builtin, but the name is kept so that keyword
        # callers keep working. Merging first avoids a TypeError on duplicate keys.
        return cls(**{**dict, **kwargs})
I find that loading a batch of 32 stereo clips of 220,500 frames each (5 s of audio at 44.1 kHz) on a CPU takes very long: roughly 30 s to 1 min per batch.
My questions are:
- How can I speed up the dataloading process?
- Would storing the audio as numpy arrays instead of wav files help?
P.S. I am new to the audio domain. If you spot any anti-patterns or inefficiencies in what I am doing, please point them out.