Hi folks,
I am currently having some issues translating some code to work on real time. Basically, I am doing a STFT/iSTFT in offline mode, that I need to replace with FFT/iFFT in real time. However, I am finding some apparent differences between torch.stft
and torch.istft
compared to torch.fft.rfft
and torch.fft.irfft
that I can’t still figure out where they come from.
In the following code, I am performing a stft with fft_size
of 256
and hop_size
of 128
. First, I’m using torch.stft
and torch.istft
to process a tensor comprised of two frames, and then I am doing the same by separately for each frame and them I’m trying to assemble then doing manual overlap-add. Somehow the results are not exactly the same even using the exact same parameters, so here come my questions:
-
Why
torch.istft
reconstructs the signal with 1 sample less than the original? The original signal has exactly 384 (256+128) samples, therefore I expected exactly 2 frames of 256 with 128 overlap. If I add the length parameter totorch.istft
I seem to have the expected behavior. -
If I compare each individual frame’s fft I get close enough results when compared with
torch.stft
in such a way thattorch.allclose()
returnsTrue
. However, if I usetorch.istft
to do overlap-add and then I compared it with overlapping both frames manually, I get the exact same results in the non-overlapping regions, however the overlap does not seem to be equal. Why is it happening? I am using no windowing to simplify the code.
Thanks for your help!
Code:
import torch
fft_size = 256
window_size = fft_size
hop_size = 128
normalized = False
onesided = True
center = False
if __name__ == "__main__":
torch.manual_seed(0)
# Simulate signal of two frames of fft_size with hop_size overlap with
# dimensions (batch_size, n_samples)
x_frames = torch.randn((1, fft_size + hop_size),
dtype=torch.float32)
# Apply stft
x_stft = torch.stft(x_frames, n_fft=fft_size, hop_length=hop_size,
center=center, normalized=normalized,
onesided=onesided, return_complex=True)
# Apply istft
x_istft = torch.istft(x_stft, n_fft=fft_size, hop_length=hop_size,
center=center, normalized=normalized,
onesided=onesided)
x_istft_with_length = torch.istft(x_stft, n_fft=fft_size,
hop_length=hop_size, center=center,
normalized=normalized, onesided=onesided,
length=x_frames.size(-1))
# Check reconstruction
is_rec_allclose = torch.allclose(x_frames, x_istft_with_length)
print(f"Are original and reconstructed signal close?: {is_rec_allclose}")
# Now let's do individual frame reconstruction
x_frames_0 = x_frames[..., :fft_size]
x_frames_1 = x_frames[..., hop_size:]
# Apply fft
x_frames_0_fft = torch.fft.rfft(x_frames_0, dim=-1)
x_frames_1_fft = torch.fft.rfft(x_frames_1, dim=-1)
# Compare individual frames with the ones transformed with stft
is_frame_0_allclose = torch.allclose(x_stft[..., 0], x_frames_0_fft)
print(f"Is frame_0 fft close to frame_0 stft?: {is_frame_0_allclose}")
is_frame_1_allclose = torch.allclose(x_stft[..., 1], x_frames_1_fft)
print(f"Is frame_1 fft close to frame_1 stft?: {is_frame_1_allclose}")
# Apply ifft
x_frames_0_ifft = torch.fft.irfft(x_frames_0_fft, dim=-1)
x_frames_1_ifft = torch.fft.irfft(x_frames_1_fft, dim=-1)
# Compare non-overlapping part of individual frames with the ones
# transformed with istft
is_frame_0_ifft_allclose = torch.allclose(
x_istft_with_length[..., :hop_size], x_frames_0_ifft[..., :hop_size])
print(f"Is frame_0 ifft close to frame_0 istft?: {is_frame_0_ifft_allclose}")
is_frame_1_ifft_allclose = torch.allclose(
x_istft_with_length[..., -hop_size:], x_frames_1_ifft[..., -hop_size:])
print(f"Is frame_1 ifft close to frame_1 istft?: {is_frame_1_ifft_allclose}")
# Now let's do a manual overlap and see if the resulting signal is similar
x_istft_from_frames = torch.zeros_like(x_frames, dtype=torch.float32)
x_istft_from_frames[..., :fft_size] = x_frames_0_ifft
x_istft_from_frames[..., hop_size:] += x_frames_1_ifft
is_istft_from_frames_allclose = torch.allclose(x_istft_from_frames,
x_istft_with_length)
print(f"Are both reconstructions close?: {is_istft_from_frames_allclose}")
# If not, let's look closer with an element-wise comparison
print(x_istft_from_frames == x_istft_with_length)
# Let's see the numerical differences:
print(x_istft_from_frames - x_istft_with_length)
Output:
Are original and reconstructed signal close?: False
Is frame_0 fft close to frame_0 stft?: True
Is frame_1 fft close to frame_1 stft?: True
Is frame_0 ifft close to frame_0 istft?: True
Is frame_1 ifft close to frame_1 istft?: True
Are both reconstructions close?: False
tensor([[ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True]])
tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-0.8834, -0.4189, -0.8048, 0.5656, 0.6104, 0.4669, 1.9507, -1.0631,
-0.0773, 0.1164, -0.5940, -1.2439, -0.1021, -1.0335, -0.3126, 0.2458,
-0.2596, 0.1183, 0.2440, 1.1646, 0.2886, 0.3866, -0.2011, -0.1179,
0.1922, -0.7722, -1.9003, 0.1307, -0.7043, 0.3147, 0.1574, 0.3854,
0.9671, -0.9911, 0.3016, -0.1073, 0.9985, -0.4987, 0.7611, 0.6183,
0.3140, 0.2133, -0.1201, 0.3605, -0.3140, -1.0787, 0.2408, -1.3962,
-0.0661, -0.3584, -1.5616, -0.3546, 1.0811, 0.1315, 1.5735, 0.7814,
-1.0787, -0.7209, 1.4708, 0.2756, 0.6668, -0.9944, -1.1894, -1.1959,
-0.5596, 0.5335, 0.4069, 0.3946, 0.1715, 0.8760, -0.2871, 1.0216,
-0.0744, -1.0922, 0.3920, 0.5945, 0.6623, -1.2063, 0.6074, -0.5472,
1.1711, 0.0975, 0.9634, 0.8403, -1.2537, 0.9868, -0.4947, -1.2830,
0.9552, 1.2836, -0.6659, 0.5651, 0.2877, -0.0334, -1.0619, -0.1144,
-0.3433, 1.5713, 0.1916, 0.3799, -0.1448, 0.6376, -0.2813, -1.3299,
-0.1420, -0.5341, -0.5234, 0.8615, -0.8870, 0.8388, 1.1529, -1.7611,
-1.4777, -1.7557, 0.0762, -1.0786, 1.4403, -0.1106, 0.5769, -0.1692,
-0.0640, 1.0384, 0.9068, -0.4755, -0.8707, 0.1447, 1.9029, 0.3904,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])