PyTorch STFT different from overlapped FFT

Hi folks,

I am currently having some issues translating some code to work in real time. Basically, I am doing an STFT/iSTFT in offline mode that I need to replace with a per-frame FFT/iFFT in real time. However, I am finding some apparent differences between torch.stft and torch.istft compared to torch.fft.rfft and torch.fft.irfft that I still can't pin down.

In the following code, I am performing an STFT with an fft_size of 256 and a hop_size of 128. First, I use torch.stft and torch.istft to process a tensor made up of two overlapping frames; then I transform each frame separately and reassemble the result with a manual overlap-add. Somehow the results are not exactly the same, even with the exact same parameters, so here are my questions:

  1. Why does torch.istft reconstruct the signal with one sample less than the original? The original signal has exactly 384 (256 + 128) samples, so I expected exactly 2 frames of 256 with an overlap of 128. If I pass the length parameter to torch.istft, I get the expected behavior. (My expected-length arithmetic is in the sketch right after this list.)

  2. If I compare each individual frame's FFT against the corresponding torch.stft frame, the results are close enough that torch.allclose() returns True. However, if I let torch.istft do the overlap-add and compare that with overlapping both frames manually, the non-overlapping regions match exactly while the overlap does not. Why is that happening? I am using no windowing to keep the code simple. (See the sketch at the bottom of this post for what I suspect.)
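
For reference, here is the length arithmetic I have in mind for question 1. This is just my own reasoning about how center=False frames the signal, not something taken from the docs, so please correct me if the framing assumption is wrong:

fft_size = 256
hop_size = 128
n_samples = fft_size + hop_size  # 384, as in the script below

# Frames that stft should extract with center=False (my framing assumption)
n_frames = 1 + (n_samples - fft_size) // hop_size      # -> 2

# Length I would expect istft to reconstruct from those frames
expected_len = fft_size + hop_size * (n_frames - 1)    # -> 384, yet istft returns 383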

Thanks for your help!

Code:

import torch

fft_size = 256
window_size = fft_size
hop_size = 128
normalized = False
onesided = True
center = False


if __name__ == "__main__":
    torch.manual_seed(0)

    # Simulate signal of two frames of fft_size with hop_size overlap with
    # dimensions (batch_size, n_samples)
    x_frames = torch.randn((1, fft_size + hop_size),
                           dtype=torch.float32)

    # Apply stft
    x_stft = torch.stft(x_frames, n_fft=fft_size, hop_length=hop_size,
                        center=center, normalized=normalized,
                        onesided=onesided, return_complex=True)

    # Apply istft
    x_istft = torch.istft(x_stft, n_fft=fft_size, hop_length=hop_size,
                          center=center, normalized=normalized,
                          onesided=onesided)

    x_istft_with_length = torch.istft(x_stft, n_fft=fft_size,
                                      hop_length=hop_size, center=center,
                                      normalized=normalized, onesided=onesided,
                                      length=x_frames.size(-1))

    # Check reconstruction
    is_rec_allclose = torch.allclose(x_frames, x_istft_with_length)
    print(f"Are original and reconstructed signal close?: {is_rec_allclose}")

    # Now let's do individual frame reconstruction
    x_frames_0 = x_frames[..., :fft_size]
    x_frames_1 = x_frames[..., hop_size:]

    # Apply fft
    x_frames_0_fft = torch.fft.rfft(x_frames_0, dim=-1)
    x_frames_1_fft = torch.fft.rfft(x_frames_1, dim=-1)

    # Compare individual frames with the ones transformed with stft
    is_frame_0_allclose = torch.allclose(x_stft[..., 0], x_frames_0_fft)
    print(f"Is frame_0 fft close to frame_0 stft?: {is_frame_0_allclose}")

    is_frame_1_allclose = torch.allclose(x_stft[..., 1], x_frames_1_fft)
    print(f"Is frame_1 fft close to frame_1 stft?: {is_frame_1_allclose}")

    # Apply ifft
    x_frames_0_ifft = torch.fft.irfft(x_frames_0_fft, dim=-1)
    x_frames_1_ifft = torch.fft.irfft(x_frames_1_fft, dim=-1)

    # Compare non-overlapping part of individual frames with the ones
    # transformed with istft
    is_frame_0_ifft_allclose = torch.allclose(
        x_istft_with_length[..., :hop_size], x_frames_0_ifft[..., :hop_size])
    print(f"Is frame_0 ifft close to frame_0 istft?: {is_frame_0_ifft_allclose}")

    is_frame_1_ifft_allclose = torch.allclose(
        x_istft_with_length[..., -hop_size:], x_frames_1_ifft[..., -hop_size:])
    print(f"Is frame_1 ifft close to frame_1 istft?: {is_frame_1_ifft_allclose}")

    # Now let's do a manual overlap and see if the resulting signal is similar
    x_istft_from_frames = torch.zeros_like(x_frames, dtype=torch.float32)
    x_istft_from_frames[..., :fft_size] = x_frames_0_ifft
    x_istft_from_frames[..., hop_size:] += x_frames_1_ifft

    is_istft_from_frames_allclose = torch.allclose(x_istft_from_frames,
                                                   x_istft_with_length)
    print(f"Are both reconstructions close?: {is_istft_from_frames_allclose}")

    # If not, let's look closer with an element-wise comparison
    print(x_istft_from_frames == x_istft_with_length)

    # Let's see the numerical differences:
    print(x_istft_from_frames - x_istft_with_length)

Output:

Are original and reconstructed signal close?: False
Is frame_0 fft close to frame_0 stft?: True
Is frame_1 fft close to frame_1 stft?: True
Is frame_0 ifft close to frame_0 istft?: True
Is frame_1 ifft close to frame_1 istft?: True
Are both reconstructions close?: False
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True]])
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         -0.8834, -0.4189, -0.8048,  0.5656,  0.6104,  0.4669,  1.9507, -1.0631,
         -0.0773,  0.1164, -0.5940, -1.2439, -0.1021, -1.0335, -0.3126,  0.2458,
         -0.2596,  0.1183,  0.2440,  1.1646,  0.2886,  0.3866, -0.2011, -0.1179,
          0.1922, -0.7722, -1.9003,  0.1307, -0.7043,  0.3147,  0.1574,  0.3854,
          0.9671, -0.9911,  0.3016, -0.1073,  0.9985, -0.4987,  0.7611,  0.6183,
          0.3140,  0.2133, -0.1201,  0.3605, -0.3140, -1.0787,  0.2408, -1.3962,
         -0.0661, -0.3584, -1.5616, -0.3546,  1.0811,  0.1315,  1.5735,  0.7814,
         -1.0787, -0.7209,  1.4708,  0.2756,  0.6668, -0.9944, -1.1894, -1.1959,
         -0.5596,  0.5335,  0.4069,  0.3946,  0.1715,  0.8760, -0.2871,  1.0216,
         -0.0744, -1.0922,  0.3920,  0.5945,  0.6623, -1.2063,  0.6074, -0.5472,
          1.1711,  0.0975,  0.9634,  0.8403, -1.2537,  0.9868, -0.4947, -1.2830,
          0.9552,  1.2836, -0.6659,  0.5651,  0.2877, -0.0334, -1.0619, -0.1144,
         -0.3433,  1.5713,  0.1916,  0.3799, -0.1448,  0.6376, -0.2813, -1.3299,
         -0.1420, -0.5341, -0.5234,  0.8615, -0.8870,  0.8388,  1.1529, -1.7611,
         -1.4777, -1.7557,  0.0762, -1.0786,  1.4403, -0.1106,  0.5769, -0.1692,
         -0.0640,  1.0384,  0.9068, -0.4755, -0.8707,  0.1447,  1.9029,  0.3904,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])
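
Edit: reading the torch.istft docs again, the inverse is described as a least-squares estimate of the original signal, which (as far as I understand) means the overlap-add is normalized by the sum of squared windows. With window=None a rectangular window of ones is used, so that envelope equals 2 in the overlapped region. Below is a sketch of a manual overlap-add that accounts for this, meant to be appended to the script above; the envelope construction is my assumption about what istft does internally:

window = torch.ones(fft_size)  # istft's default rectangular window

# Squared-window envelope accumulated over the two frame positions
# (assumed to match istft's internal normalization)
envelope = torch.zeros(x_frames.size(-1))
envelope[:fft_size] += window ** 2
envelope[hop_size:hop_size + fft_size] += window ** 2  # equals 2 in the overlap

# Manual overlap-add with windowing, then division by the envelope
x_ola = torch.zeros_like(x_frames)
x_ola[..., :fft_size] += x_frames_0_ifft * window
x_ola[..., hop_size:] += x_frames_1_ifft * window
x_ola = x_ola / envelope

# Should print True if the envelope assumption is correct
print(torch.allclose(x_ola, x_istft_with_length))

If that is indeed the missing piece, the real-time per-frame path would just need the same envelope division (with a rectangular window, simply halving the overlapped region). Can anyone confirm?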