Unable to move window to GPU for torchaudio.Spectrogram

I'm getting this error:

RuntimeError: stft input and window must be on the same device but got self on cuda:0 and window on cpu

The problem seems to be that the window is not on the GPU. I can move it to the GPU by creating the window as a tensor.

window = torch.hamming_window(window_length = window_len, device = "cuda")

But the Spectrogram transform only seems to accept the window as a function (window_fn), not as a tensor, at least from my testing/understanding.

Here’s a reproducible example, tested on torch 1.12.1 and torchaudio 0.12.1.

import torch
import torchaudio

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

window_len = 256
fft_len = 256
overlap = window_len//4
window = torch.hamming_window

waveform = torch.randn(1, 16000*2).to(device)

specgram = torchaudio.transforms.Spectrogram(n_fft=fft_len, win_length=window_len, hop_length=overlap, window_fn = window)(waveform).to(device)

Any help is greatly appreciated!

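Spectrogram is an nn.Module, and its window is registered as a buffer (see register_buffer in the source below).
Just move the module to the GPU when you create the instance, and the window moves with it.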

class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from a audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy (Default: ``True``)
        return_complex (bool, optional):
            Indicates whether the resulting complex-valued Tensor should be represented with
            native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
            mimicking complex value with an extra dimension for real and imaginary parts.
            (See also ``torch.view_as_real``.)
            This argument is only effective when ``power=None``. It is ignored for
            cases where ``power`` is a number as in those cases, the returned tensor is
            power spectrogram, which is a real-valued tensor.

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.Spectrogram(n_fft=800)
        >>> spectrogram = transform(waveform)
    """

    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 return_complex: bool = True) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided
        self.return_complex = return_complex

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(

Or just use the functional version (torchaudio.functional.spectrogram) with your own window tensor.

Great, that fixed it. Thanks a lot!