`torchaudio.transforms.Spectrogram` is an `nn.Module`, and its window tensor is registered as a buffer.

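So just move it to the GPU when you create the instance, e.g. `transform = torchaudio.transforms.Spectrogram(n_fft=800).to("cuda")`. Calling `.to()` on the module also moves the window buffer, so no extra work is needed.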

```
class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    The analysis window is registered as a module buffer, so moving the module
    with ``.to(device)`` / ``.cuda()`` also moves the window to that device.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for magnitude, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy (Default: ``True``)
        return_complex (bool, optional):
            Indicates whether the resulting complex-valued Tensor should be represented with
            native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
            mimicking complex value with an extra dimension for real and imaginary parts.
            (See also ``torch.view_as_real``.)
            This argument is only effective when ``power=None``. It is ignored for
            cases where ``power`` is a number as in those cases, the returned tensor is
            power spectrogram, which is a real-valued tensor.

    Example:
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.Spectrogram(n_fft=800)
        >>> spectrogram = transform(waveform)
    """

    # Attributes TorchScript may treat as compile-time constants.
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 return_complex: bool = True) -> None:
        super().__init__()
        # Number of FFT points. With onesided=True the STFT result has
        # n_fft // 2 + 1 frequency bins.
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        # A buffer (not a Parameter): it is not trained, but it follows the
        # module through .to()/.cuda() and is included in state_dict().
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided
        self.return_complex = return_complex

    def forward(self, waveform: Tensor) -> Tensor:
        r"""Compute the spectrogram of ``waveform``.

        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` when ``onesided`` is True (``n_fft`` otherwise),
            ``n_fft`` being the number of Fourier bins, and time is the number
            of window hops (n_frame). The result is complex when ``power`` is None.
        """
        # Argument order must match F.spectrogram's positional signature.
        return F.spectrogram(
            waveform,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.power,
            self.normalized,
            self.center,
            self.pad_mode,
            self.onesided,
            self.return_complex,
        )
```

Alternatively, call the functional form (`torchaudio.functional.spectrogram`) directly and pass your own window tensor, created on the GPU.