This is my first post, so please tell me if I made a mistake.
Hi, I am working on a speech enhancement problem, with an STFT → modification in the frequency domain → iSTFT workflow.
My problem is that I can only reconstruct the full signal that I passed into torch.stft when I use the center=True
option. With center=False, torch.istft raises an error, while the librosa implementation handles both cases. Maybe I am missing something: how can I achieve the same behavior in PyTorch?
The following is a dummy setup to transform a sine wave and transform it back.
import torch
import matplotlib.pyplot as plt
import numpy as np
import librosa
n_fft = 32
# Example signal:
signal = (torch.linspace(0, 2*n_fft, 2*n_fft) * 2 * np.pi).sin()
# Parameters
# stft parameters with and without center
centered = {
'n_fft': n_fft,
'hop_length': n_fft // 2,
'win_length': n_fft,
'window': torch.hann_window(n_fft),
'center': True,
'return_complex': True,
}
uncentered = centered.copy()
uncentered['center'] = False
# parameters for librosa stft for comparison
lr_centered = {
'n_fft': n_fft,
'hop_length': n_fft // 2,
'win_length': n_fft,
'window': 'hann',
'center': True,
}
lr_uncentered = lr_centered.copy()
lr_uncentered['center'] = False
# parameters for the istft
i_centered = {
'n_fft': n_fft,
'hop_length': n_fft // 2,
'win_length': n_fft,
'window': torch.hann_window(n_fft),
'center': True,
'return_complex': False,
# 'length': len(signal),
}
i_uncentered = i_centered.copy()
i_uncentered['center'] = False
i_lr_centered = {
'hop_length': n_fft // 2,
'win_length': n_fft,
'window': 'hann',
'center': True,
}
i_lr_uncentered = i_lr_centered.copy()
i_lr_uncentered['center'] = False
# stfts
stft_centered = signal.stft(**centered)
stft_uncentered = signal.stft(**uncentered)
stft_lr_centered = librosa.stft(signal.numpy(), **lr_centered)
stft_lr_uncentered = librosa.stft(signal.numpy(), **lr_uncentered)
# istfts
i_stft_centered = stft_centered.istft(**i_centered)
# i_stft_uncentered = stft_uncentered.istft(**i_uncentered) # ! this causes an error!
# RuntimeError: istft(CPUComplexFloatType[17, 3], n_fft=32, hop_length=16, win_length=32,
# window=torch.FloatTensor{[32]}, center=0, normalized=0, onesided=None, length=None,
# return_complex=0) window overlap add min: 0
i_stft_lr_centered = librosa.istft(stft_lr_centered, **i_lr_centered)
i_stft_lr_uncentered = librosa.istft(stft_lr_uncentered, **i_lr_uncentered)
# I used the centered parameters to see what happens:
i_stft_uncentered = stft_uncentered.istft(**i_centered)
With the code above, I get the following results:
The input signal:
[Had to delete this picture, because new users can post only one]
The stfts:
[Had to delete this picture, because new users can post only one]
And the reconstructed signals:
As we can see, the librosa implementation handles both centered and uncentered STFTs with complete reconstruction, but the PyTorch implementation raises a RuntimeError at i_stft_uncentered = stft_uncentered.istft(**i_uncentered)
(see the commented-out call and error message above).
Is this a limitation of the PyTorch implementation, or am I missing a way to get the same behavior in PyTorch?
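One possible reading of the `window overlap add min: 0` message, sketched under assumptions rather than taken from the PyTorch source: the inverse transform divides by the overlap-added squared window, and with center=False that envelope includes the very first sample, where a Hann window is exactly zero. The helper names `hann_periodic` and `squared_window_envelope` below are hypothetical and written in plain Python. The window formula is assumed to match `torch.hann_window(n_fft)` with its default `periodic=True`, and the trimming of `n_fft // 2` samples per side is assumed to be what center=True does on the inverse side.

```python
import math


def hann_periodic(n):
    # Periodic Hann window (assumption: same values as torch.hann_window(n)
    # with the default periodic=True).
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * i / n) for i in range(n)]


def squared_window_envelope(window, hop, n_frames):
    # Overlap-add one copy of the squared window per frame, shifted by `hop`.
    n = len(window)
    env = [0.0] * (n + hop * (n_frames - 1))
    for f in range(n_frames):
        for i, w in enumerate(window):
            env[f * hop + i] += w * w
    return env


# 3 frames, matching the [17, 3] shape in the error message above.
n_fft, hop, n_frames = 32, 16, 3
env = squared_window_envelope(hann_periodic(n_fft), hop, n_frames)

# center=False: the whole envelope is kept, including sample 0,
# which only the first frame covers and where the window is zero.
min_uncentered = min(env)

# center=True (assumed): n_fft // 2 samples are trimmed from each side,
# so the zero-valued edge never enters the normalization.
min_centered = min(env[n_fft // 2 : len(env) - n_fft // 2])

print("min envelope, untrimmed:", min_uncentered)
print("min envelope, trimmed:  ", min_centered)
```

If this reading is right, the untrimmed minimum is zero and the division is undefined at the signal edge, while the trimmed envelope stays well above zero. That would also explain why passing the uncentered STFT to istft with the centered parameters runs without an error, although the output is then shorter than the input.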