STFT/iSTFT length mismatch

Hi everyone, I use a model to process the signal's spectrogram and then use istft to convert it back to a waveform. The tensor fed to istft has the same shape as the original STFT, but the waveform after istft is not the same length as the original signal. I read the discussion at https://github.com/pytorch/pytorch/issues/62747 on GitHub, and it seems the length problem has been solved there, but I still run into it. Here is my code; forgive me if the layout is a bit messy.

```python
import os

import torch
import torchaudio

# MMDenseNet, `names` and `musdb18_root` are presumably defined elsewhere
# in the original script (DNN-based_source_separation project); not shown here.


def main():
    sample_rate = 44100
    device = torch.device('cuda:0')
    model_path = '/data2/Will/DNN-based_source_separation-main/DNN-based_source_separation-main/src/TEST/best.pth'
    model = MMDenseNet.build_model(model_path)  # the model
    config = torch.load(model_path, map_location=lambda storage, loc: storage)  # the checkpoint (dict)
    model.load_state_dict(config['state_dict'])
    model = model.to(device)
    model.eval()
    channel = model.in_channels
    n_fft = 2048
    hop = 1024
    window_fn = torch.hann_window(n_fft, periodic=True, device=device)

    # Output directory for the separated tracks; ABS_path is set in both cases.
    if os.path.isdir('TT'):
        print('exist')
    else:
        os.mkdir('TT')
    ABS_path = os.path.join(os.getcwd(), 'TT')

    for name in names:
        mixture_path = os.path.join(musdb18_root, 'test', name, "mixture.wav")
        with torch.no_grad():
            source, sr = torchaudio.load(mixture_path)  # (channels, samples)
            source = source.to(device)
            source_duration = source.size(1) / 44100
            # Complex STFT of shape (channels, n_fft // 2 + 1, n_frames), then add a batch dim.
            source_stft = torch.stft(source, n_fft=n_fft, hop_length=hop, window=window_fn, return_complex=True)
            source_stft = torch.unsqueeze(source_stft, dim=0)
            print(source_stft.shape)
            # The model predicts magnitudes; the mixture phase is reused.
            source_stft_amp = torch.abs(source_stft)
            source_stft_angle = torch.angle(source_stft)
            estimated_amp = model(source_stft_amp)
            print(estimated_amp.shape)
            estimated = estimated_amp * torch.exp(1j * source_stft_angle)
            channels = estimated.size()[:-2]  # keep (B, C)
            estimated = estimated.view(-1, *estimated.size()[-2:])  # flatten to (B*C, F, T)
            print(estimated.shape)
            estimated_out = torch.istft(estimated, n_fft=n_fft, hop_length=hop, window=window_fn, return_complex=False)
            estimated_out = estimated_out.view(*channels, -1).squeeze(0).cpu()
            print('TEST', estimated_out.shape, source.shape)
        est_path = os.path.join(ABS_path, '{}.wav'.format(name))
        torchaudio.save(est_path, estimated_out, sample_rate=sample_rate, channels_first=True, bits_per_sample=16)
```

Could this be caused by my PyTorch version? I'm using PyTorch 1.12.1.
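To take the model out of the picture, a minimal CPU-only round trip may help. The sketch below assumes the same `n_fft=2048`, `hop=1024` and periodic Hann window as the code above, and uses a hypothetical random stereo `signal` whose length (`num_samples`) is deliberately not a multiple of the hop size. It compares `torch.istft` with and without its `length` argument.

```python
import torch

# Assumed parameters matching the post: n_fft=2048, hop=1024, periodic Hann window.
n_fft, hop = 2048, 1024
window = torch.hann_window(n_fft, periodic=True)

# Hypothetical stereo signal whose length is NOT a multiple of hop.
num_samples = 10 * hop + 468
signal = torch.randn(2, num_samples)

# Forward STFT (center=True is the default): shape (2, n_fft // 2 + 1, n_frames).
spec = torch.stft(signal, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)

# Without `length`, istft returns hop * (n_frames - 1) samples.
recon_default = torch.istft(spec, n_fft=n_fft, hop_length=hop, window=window)

# With `length`, istft returns exactly the requested number of samples.
recon_fixed = torch.istft(spec, n_fft=n_fft, hop_length=hop, window=window,
                          length=signal.size(-1))

print(spec.shape, recon_default.shape, recon_fixed.shape)
```

With these numbers, `spec` has 11 frames, `recon_default` has 10240 samples (468 short of the input), and `recon_fixed` has the full 10708 samples and matches `signal` up to float rounding.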


Are you able to reproduce the expected (and fixed) output shape using the code snippet from the linked issue?

```python
import torch

stft_feat = torch.randn(1, 257, 469).to(torch.complex64)
print("stft_feat", stft_feat.size())

n_fft = 512
hop_length = 320
win_length = 512
window = torch.hann_window(win_length)
center = True
length = 150079
reconstruct = torch.istft(stft_feat, n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, length=length)[0]
print("reconstruct", reconstruct.size())
```

It works for me with torch==2.4.0.dev20240506+cu124, and I see the expected output: `reconstruct torch.Size([150079])`.

Yes, I can run it successfully too. The output is the same, just with a warning:

```
stft_feat torch.Size([1, 257, 469])
/data2/Will/DNN-based_source_separation-main/DNN-based_source_separation-main/src/real_time_test/TT.py:12: UserWarning: The length of signal is shorter than the length parameter. Result is being padded with zeros in the tail. Please check your center and hop_length settings. (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484809662/work/aten/src/ATen/native/SpectralOps.cpp:1108.)
  reconstruct = torch.istft(stft_feat, n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, length=length)[0]
reconstruct torch.Size([150079])
```
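A quick size calculation suggests why the warning fires for this input. It assumes `torch.istft` follows the trim-then-pad logic in ATen's `SpectralOps.cpp` (the file named in the warning): the overlap-added frames span `n_fft + hop_length * (n_frames - 1)` samples, with `center=True` the first `n_fft // 2` are dropped, then `length` samples are kept and any part the frames do not cover is zero-padded. The variable names below are illustrative only.

```python
# Size bookkeeping for the snippet above (n_fft=512, hop_length=320, 469 frames, center=True).
n_fft, hop_length, n_frames, length = 512, 320, 469, 150079

# Samples spanned by overlap-adding all frames, before removing the center padding.
spanned = n_fft + hop_length * (n_frames - 1)

# With center=True, istft starts n_fft // 2 samples in and keeps `length` samples.
start = n_fft // 2
end = start + length

# Samples past the spanned region cannot come from any frame, so they are zero-padded.
zero_padded = max(0, end - spanned)

# For comparison: the output length when `length` is not passed (even n_fft).
default_len = hop_length * (n_frames - 1)

print(spanned, end, zero_padded, default_len)
```

Here `end` (150335) exceeds `spanned` (150272), so the last 63 output samples are zeros, which matches the warning's "padded with zeros in the tail".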

However, when I run my own program, I can't recover the original length; the output is slightly shorter than the original signal:
```
source_stft.shape : torch.Size([1, 2, 1025, 7735])
model output : torch.Size([1, 2, 1025, 7735])
reshape model output before istft : torch.Size([2, 1025, 7735])
after istft : torch.Size([2, 7919616])
source waveform shape : torch.Size([2, 7920084])
```
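The numbers in this log can be reproduced with plain arithmetic, assuming the `torch.stft`/`torch.istft` defaults used in the code (`center=True`, even `n_fft`) and that `length` is not passed to `istft`. The last part additionally assumes the trim-then-pad behaviour of ATen's istft, where output sample `i` comes from position `n_fft // 2 + i` of the overlap-added frames.

```python
# Size bookkeeping for the shapes in the log above.
n_fft, hop = 2048, 1024
num_samples = 7920084  # length of the original waveform (last dim of `source`)

# center=True pads n_fft // 2 on both sides, so stft yields 1 + num_samples // hop frames.
n_frames = 1 + num_samples // hop       # matches 7735 in source_stft.shape

# Without `length`, istft returns hop * (n_frames - 1) samples.
istft_len = hop * (n_frames - 1)        # matches 7919616 in "after istft"
missing = num_samples - istft_len       # samples lost at the tail

# With length=num_samples, istft keeps samples [n_fft // 2, n_fft // 2 + num_samples)
# of the overlap-added signal, which spans n_fft + hop * (n_frames - 1) samples.
spanned = n_fft + hop * (n_frames - 1)
end = n_fft // 2 + num_samples
needs_zero_padding = end > spanned      # False: every requested sample is covered by a frame

print(n_frames, istft_len, missing, needs_zero_padding)
```

If this holds, the 468 missing samples come from the last partial hop, and passing the original length in the loop, e.g. `torch.istft(estimated, n_fft=n_fft, hop_length=hop, window=window_fn, length=source.size(-1))`, should make `estimated_out` the same length as `source` without any zero padding.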

I ran the same program on another workstation, and there it did reproduce the original signal length. That machine has torch 2.0.1+cu118.