How to restore the full signal from a non-centered STFT?

This is my first post, so please tell me if I made a mistake.

Hi, I am working on a speech enhancement problem with an STFT → frequency-domain modification → iSTFT workflow.
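For reference, here is a minimal sketch of that workflow using the centered configuration, which does round-trip. The all-ones `mask` is a hypothetical placeholder for whatever the enhancement model predicts. Passing `length` to `torch.istft` is one way to get back exactly as many samples as went in.

```python
import torch

torch.manual_seed(0)
n_fft = 32
hop_length = n_fft // 2
window = torch.hann_window(n_fft)
signal = torch.randn(1000)

# 1) Analysis: centered STFT (center=True pads n_fft // 2 samples on each side).
spec = torch.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
                  window=window, center=True, return_complex=True)

# 2) Modification: hypothetical real-valued gain mask; all ones is a no-op.
mask = torch.ones_like(spec.real)
enhanced_spec = spec * mask

# 3) Synthesis: `length` makes the output exactly as long as the input.
enhanced = torch.istft(enhanced_spec, n_fft=n_fft, hop_length=hop_length,
                       win_length=n_fft, window=window, center=True,
                       length=signal.shape[-1])

print(enhanced.shape)  # torch.Size([1000])
print(torch.allclose(enhanced, signal, atol=1e-4))
```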

My problem is that I have only managed to reconstruct the full signal passed into torch.stft when using center=True. With librosa, the uncentered case seems to work as well, so maybe I am missing something: how can I get the same behavior in PyTorch?
The following is a dummy setup that transforms a sine wave into the frequency domain and back.

```python
import torch
import matplotlib.pyplot as plt
import numpy as np
import librosa

n_fft = 32
# Example signal:
signal = (torch.linspace(0, 2*n_fft, 2*n_fft) * 2 * np.pi).sin()

# Parameters

# stft parameters with and without center
centered = {
    'n_fft': n_fft,
    'hop_length': n_fft // 2,
    'win_length': n_fft,
    'window': torch.hann_window(n_fft),
    'center': True,
    'return_complex': True,
}
uncentered = centered.copy()
uncentered['center'] = False

# parameters for librosa stft for comparison
lr_centered = {
    'n_fft': n_fft,
    'hop_length': n_fft // 2,
    'win_length': n_fft,
    'window': 'hann',
    'center': True,
}

lr_uncentered = lr_centered.copy()
lr_uncentered['center'] = False

# parameters for the istft
i_centered = {
    'n_fft': n_fft,
    'hop_length': n_fft // 2,
    'win_length': n_fft,
    'window': torch.hann_window(n_fft),
    'center': True,
    'return_complex': False,
    # 'length': len(signal),
}
i_uncentered = i_centered.copy()
i_uncentered['center'] = False

i_lr_centered = {
    'hop_length': n_fft // 2,
    'win_length': n_fft,
    'window': 'hann',
    'center': True,
}
i_lr_uncentered = i_lr_centered.copy()
i_lr_uncentered['center'] = False

# stfts
stft_centered = signal.stft(**centered)
stft_uncentered = signal.stft(**uncentered)
stft_lr_centered = librosa.stft(signal.numpy(), **lr_centered)
stft_lr_uncentered = librosa.stft(signal.numpy(), **lr_uncentered)

# istfts
i_stft_centered = stft_centered.istft(**i_centered)
# i_stft_uncentered = stft_uncentered.istft(**i_uncentered) # ! this causes an error!
# RuntimeError: istft(CPUComplexFloatType[17, 3], n_fft=32, hop_length=16, win_length=32, 
# window=torch.FloatTensor{[32]}, center=0, normalized=0, onesided=None, length=None, 
# return_complex=0) window overlap add min: 0
i_stft_lr_centered = librosa.istft(stft_lr_centered, **i_lr_centered)
i_stft_lr_uncentered = librosa.istft(stft_lr_uncentered, **i_lr_uncentered)

# I used the centered parameters to see what happens:
i_stft_uncentered = stft_uncentered.istft(**i_centered)
```

With the code above, I get the following results:
The input signal:

[Picture removed: new users can only post one image]

The STFTs:

[Picture removed: new users can only post one image]

And the reconstructed signals:
[Figure: reconstructed signals]

As you can see, librosa fully reconstructs the signal from both the centered and the uncentered STFT. PyTorch, however, fails at i_stft_uncentered = stft_uncentered.istft(**i_uncentered) with the error shown in the code comments above.

Is this a limitation of the PyTorch implementation, or am I missing a way to get the same behavior in PyTorch?
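The error seems to come from the condition stated in the `torch.istft` documentation. The overlap-added squared window, sum over t of w[n - t * hop_length]^2, must never be zero. `torch.hann_window` is periodic by default, so its first sample is exactly 0. With `center=False`, the first output sample is covered only by that zero sample of the first frame.

The sketch below recomputes that envelope by hand for the uncentered example. `n_frames = 3` matches the `[17, 3]` shape in the error message. This is a diagnostic written for this post, not PyTorch's internal code.

```python
import torch

n_fft = 32
hop_length = n_fft // 2
n_frames = 3  # frames produced by torch.stft(center=False) for the 64-sample example

# Periodic Hann window (torch default): the first sample is exactly 0.
window = torch.hann_window(n_fft)

# Overlap-add envelope: sum over frames of the shifted squared window.
out_len = n_fft + hop_length * (n_frames - 1)
envelope = torch.zeros(out_len)
for t in range(n_frames):
    envelope[t * hop_length : t * hop_length + n_fft] += window ** 2

print(window[0].item())                        # 0.0
print(envelope.min().item())                   # 0.0
print((envelope < 1e-11).nonzero().flatten())  # tensor([0]): only sample 0 is zero

# Dropping the first n_fft // 2 samples (what center=True trimming does)
# removes the only zero of the envelope.
print(envelope[n_fft // 2 :].min().item() > 0)
```

With `center=True`, the first `n_fft // 2` output samples come from the padding and are trimmed before the check. That would explain why only the centered version works.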


The STFTs:
[Figure: the STFTs]

I hope I am not breaking any rules by getting around the one-image restriction like this.
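One possible workaround is a window that never reaches zero, if changing the analysis window is acceptable for your model. The sketch below assumes a periodic Hamming window, which is my choice and not something from this thread. Its first sample is 0.54 - 0.46 = 0.08, so the envelope has no zeros and `torch.istft` accepts `center=False`.

```python
import torch

torch.manual_seed(0)
n_fft = 32
hop_length = n_fft // 2

# Hamming never reaches zero, so the squared-window envelope is positive everywhere.
window = torch.hamming_window(n_fft, dtype=torch.float64)
signal = torch.randn(100, dtype=torch.float64)

params = dict(n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
              window=window, center=False)
spec = torch.stft(signal, return_complex=True, **params)
recon = torch.istft(spec, **params)

# center=False drops trailing samples that do not fill a whole frame:
# 5 frames cover n_fft + 4 * hop_length = 96 of the 100 input samples.
print(spec.shape)   # torch.Size([17, 5])
print(recon.shape)  # torch.Size([96])
print(torch.allclose(recon, signal[: recon.shape[0]]))
```

Note that without centering, the tail of the signal is only recovered if it fills a complete frame. Padding the input yourself to a suitable length avoids losing it.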

I am also curious about this:

```python
Z = torch.stft(X, n_fft=512, center=False, window=torch.hann_window(512))
X2 = torch.istft(Z, n_fft=512, center=False, window=torch.hann_window(512))
```

gives:

```
~/opt/miniconda3/lib/python3.8/site-packages/torch/functional.py in istft(input, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex)
    652             length=length, return_complex=return_complex)
    653 
--> 654     return _VF.istft(input, n_fft, hop_length, win_length, window, center,  # type: ignore
    655                      normalized, onesided, length, return_complex)
    656 

RuntimeError: istft(torch.DoubleTensor[257, 686, 2], n_fft=512, hop_length=128, win_length=512, window=torch.FloatTensor{[512]}, center=0, normalized=0, onesided=None, length=None, return_complex=0) window overlap add min: 0
```

Same thing here. Did you figure out what the issue might be?
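To keep the Hann window, another option is to do the overlap-add yourself and normalize only where the envelope is nonzero. As far as I can tell from librosa's source, `librosa.istft` does roughly this: it divides by the window sum only where that sum exceeds a tiny threshold.

`istft_uncentered` below is a hypothetical helper written for this post. It assumes a onesided STFT from `torch.stft(..., center=False, return_complex=True)` with the default `normalized=False`.

```python
import torch


def istft_uncentered(spec, n_fft, hop_length, window, eps=1e-11):
    """Hypothetical helper: overlap-add inverse of a onesided, center=False STFT.

    Samples whose squared-window envelope is (near) zero are left unnormalized
    instead of raising an error, similar to librosa.istft.
    """
    # Inverse real FFT of every frame: (n_fft // 2 + 1, n_frames) -> (n_fft, n_frames),
    # then apply the synthesis window.
    frames = torch.fft.irfft(spec, n=n_fft, dim=-2) * window[:, None]
    n_frames = frames.shape[-1]
    out_len = n_fft + hop_length * (n_frames - 1)
    y = torch.zeros(out_len, dtype=frames.dtype)
    envelope = torch.zeros(out_len, dtype=frames.dtype)
    for t in range(n_frames):
        start = t * hop_length
        y[start : start + n_fft] += frames[:, t]
        envelope[start : start + n_fft] += window ** 2
    # Normalize only where the envelope is safely above zero.
    nonzero = envelope > eps
    y[nonzero] = y[nonzero] / envelope[nonzero]
    return y


torch.manual_seed(0)
n_fft, hop_length = 32, 16
window = torch.hann_window(n_fft, dtype=torch.float64)
signal = torch.randn(100, dtype=torch.float64)

spec = torch.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
                  window=window, center=False, return_complex=True)
recon = istft_uncentered(spec, n_fft, hop_length, window)

# Everything except sample 0 (zero window weight) matches the covered input.
print(torch.allclose(recon[1:], signal[1:96]))
print(recon[0].item())
```

The first sample has zero window weight, so it cannot be recovered and comes back as 0. The example sine starts at sin(0) = 0, so this loss is invisible there. That may be why the uncentered librosa reconstruction looks perfect.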