Faster data loading for WAV audio

I am working with WAV audio files sampled at 44,100 Hz (44.1 kHz), which I need to load with torchaudio. My current implementation, using PyTorch and PyTorch Lightning, is shown below:

import os
import random
from typing import Any, Dict, Optional

import torchaudio
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from typing_extensions import Self


class AudioDataset(Dataset):
    """
    Handles the loading of a waveform.

    Every regular file directly inside ``data_dir`` is treated as one sample. Audio is
    decoded with torchaudio and resampled to ``sample_rate``. When ``num_frames`` is
    set, a random window is read from each file and the returned waveform is trimmed
    or zero-padded to exactly ``num_frames`` frames, so samples of one dataset can be
    stacked into a batch by the default ``DataLoader`` collate function.
    """

    def __init__(self, data_dir: str, sample_rate: int,
                 num_frames: Optional[int] = None) -> None:
        """
        Initializes the module.

        Parameters
        ----------
        data_dir : str
            Path to folder with the audio files (``~`` is expanded). Subdirectories
            are skipped.
        sample_rate : int
            Target sample rate which the audio will be resampled to.
        num_frames : int, optional
            Number of frames (at the target sample rate) to return per sample. A random
            offset from the start of the audio will be chosen when the number of frames
            is specified. ``None`` (or ``0``) loads the full file.
        """
        super().__init__()

        self.data_dir = os.path.expanduser(data_dir)
        self.sample_rate = sample_rate
        self.num_frames = num_frames

        # Sorted so the index -> file mapping is deterministic (os.listdir order is
        # arbitrary); directories are excluded because torchaudio cannot load them.
        self.filenames = sorted(
            name for name in os.listdir(self.data_dir)
            if os.path.isfile(os.path.join(self.data_dir, name))
        )

        # Resampling transforms keyed by (source_rate, target_rate). Building the sinc
        # kernel is expensive, so it is done once per rate pair instead of per item.
        self._resamplers: Dict[tuple, Any] = {}

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, index: int) -> Tensor:
        filepath = os.path.join(self.data_dir, self.filenames[index])

        if self.num_frames:
            sample = self._load_random_audio_slice(filepath)
        else:
            sample = self._load_full_audio(filepath)

        return sample

    def _load_random_audio_slice(self, filepath: str) -> Tensor:
        """Load a random window of the file, resampled and sized to ``num_frames``."""
        metadata = torchaudio.info(filepath)
        # Window length at the file's native rate, rounded up so the resampled slice
        # is never shorter than requested.
        frames_to_load = self._source_frames(metadata.sample_rate, self.sample_rate,
                                             self.num_frames)
        # The offset must leave room for the native-rate window, not for num_frames.
        frame_offset = self._random_offset(metadata.num_frames, frames_to_load)

        waveform, sample_rate = torchaudio.load(filepath, num_frames=frames_to_load,
                                                frame_offset=frame_offset,
                                                normalize=True)
        waveform = self._resample(waveform, sample_rate, self.sample_rate)

        return self._fix_length(waveform, self.num_frames)

    def _load_full_audio(self, filepath: str) -> Tensor:
        """Load the whole file and resample it to the target rate."""
        waveform, sample_rate = torchaudio.load(filepath, normalize=True)
        waveform = self._resample(waveform, sample_rate, self.sample_rate)

        return waveform

    def _resample(self, x: Tensor, sample_rate: int, target_sample_rate: int) -> Tensor:
        """Resample ``x`` to ``target_sample_rate`` using a cached transform."""
        if sample_rate == target_sample_rate:
            return x

        key = (sample_rate, target_sample_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
            self._resamplers[key] = resampler
        return resampler(x)

    @staticmethod
    def _source_frames(source_rate: int, target_rate: int, num_frames: int) -> int:
        """Return how many source-rate frames cover ``num_frames`` target-rate frames."""
        # Integer ceil division: exact for integer rates, unlike float truncation.
        return int(-(-source_rate * num_frames // target_rate))

    @staticmethod
    def _random_offset(total_frames: int, window: int) -> int:
        """Return a random start frame such that ``window`` frames fit in the file.

        Files shorter than ``window`` always start at frame 0.
        """
        return random.randint(0, max(total_frames - window, 0))

    @staticmethod
    def _fix_length(waveform: Tensor, num_frames: int) -> Tensor:
        """Trim or zero-pad the last dimension of ``waveform`` to ``num_frames``."""
        length = waveform.shape[-1]
        if length >= num_frames:
            return waveform[..., :num_frames]
        padded = waveform.new_zeros((*waveform.shape[:-1], num_frames))
        padded[..., :length] = waveform
        return padded


class AudioDataModule(LightningDataModule):
    """
    Encapsulates dataloading.

    Builds an ``AudioDataset`` from a folder of audio files in ``setup`` and serves it
    through the training ``DataLoader``.
    """

    def __init__(self, data_dir: str, sample_rate: int,
                 num_frames: Optional[int] = None, batch_size: int = 32,
                 num_workers: int = 4, pin_memory: bool = False,
                 shuffle: bool = True, **kwargs: Any) -> None:
        """
        Initializes the module.

        Parameters
        ----------
        data_dir : str
            Path to folder with the audio files.
        sample_rate : int
            Target sample rate which the audio will be resampled to.
        num_frames : int, default=None
            Number of frames to limit loading to per sample. A random offset from the
            start of the audio will be chosen when the number of frames is specified.
        batch_size : int, default=32
            Size of the mini-batch.
        num_workers : int, default=4
            Number of parallel worker processes to use for loading.
        pin_memory : bool, default=False
            Copy batches into pinned (page-locked) host memory, which speeds up
            host-to-GPU transfers.
        shuffle : bool, default=True
            Reshuffle the training samples at every epoch.
        **kwargs : Any
            These additional keyword arguments are ignored; they are accepted as a
            convenience to allow initialization from a dictionary.
        """
        super().__init__()

        self.data_dir = data_dir
        self.sample_rate = sample_rate
        self.num_frames = num_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle = shuffle

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the dataset; ``stage`` is not used because only training data exists."""
        self.dataset = AudioDataset(self.data_dir, self.sample_rate, self.num_frames)

    def train_dataloader(self) -> DataLoader:
        """Return the training ``DataLoader`` over the dataset created in ``setup``."""
        return DataLoader(self.dataset, batch_size=self.batch_size,
                          shuffle=self.shuffle,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          # Keep worker processes alive across epochs instead of
                          # re-spawning them; only valid when workers are used.
                          persistent_workers=self.num_workers > 0)

    @classmethod
    def from_dict(cls, dict: Dict[str, Any], **kwargs: Any) -> Self:
        """
        Build the module from a configuration dictionary.

        Entries in ``kwargs`` override entries with the same key in ``dict``.
        """
        # Merge first so overlapping keys override instead of raising
        # "got multiple values for keyword argument".
        return cls(**{**dict, **kwargs})

I find that loading a batch of 32 stereo clips of 220,500 frames each (5 s of audio at 44.1 kHz) on a CPU takes very long, roughly 30 s to 1 min.

My questions are:

  1. How can I speed up the dataloading process?
  2. Would storing the audio as numpy arrays instead of wav files help?

P.S.: I am new to the audio domain. If you spot any anti-patterns or inefficiencies in my code, kindly point them out.