Number of frames T output by STFT in 0.4.1

Hello everyone,

When updating torch from 0.4.0 to 0.4.1, the number of frames computed by STFT seems to have changed for a given signal length. I would like some clarification on how to relate signal duration to the number of STFT frames in version 0.4.1.

I use sr=22050 ; n_fft = 2048 ; win_size = 1024 (hann periodic) ; hop_size = 256 (75% overlap)

Consider a single time series data_in of length data_len.

I crop the signal to the number of frames up to the last complete window hop:
n_frames = int(np.floor((data_len-win_size)/hop_size))+1
crop_len = win_size+(n_frames-1)*hop_size

STFT = torch.stft(data_in[:crop_len],n_fft,hop_length=hop_size,win_length=win_size,window=hann_window,center=True,pad_mode='reflect',normalized=False,onesided=True)

Then STFT has the shape (1025, n_frames+4) which is the right number of onesided frequency bins but 4 more frames than what I expected since I cropped the signal.

The 4 extra frames appear consistently, relative to my initial calculation, across a dataset of audio files with several different input lengths, so it seems the correct formula could be

n_frames = int(np.floor((data_len-win_size)/hop_size))+5

Could anyone clarify this point please ?
How can I take an input signal length, infer the number of frames that fit in it without padding, crop the input signal accordingly, and then compute an STFT that yields the expected number of frames ?

Thanks in advance !

according to STFT output shape
it should be:
n_frames = ((data_len - (win_size - 1) - 1) / hop_size) + 1
which also does not match the actual STFT output shape:
floor(n_frames) and ceil(n_frames) are both still smaller than T as computed by STFT in 0.4.1

reversing the problem:

n_frames = 127
data_len = ((n_frames-1)*hop_size)+1+(win_size-1) (== 33280)
data_in = torch.zeros(data_len)

STFT_center = torch.stft(data_in,n_fft,hop_length=hop_size,win_length=win_size,window=window_t,center=True,pad_mode='reflect',normalized=False,onesided=True)
--> STFT_center.shape == torch.Size([1025, 131, 2])

STFT_notcenter = torch.stft(data_in,n_fft,hop_length=hop_size,win_length=win_size,window=window_t,center=False,pad_mode='reflect',normalized=False,onesided=True)
--> STFT_notcenter.shape == torch.Size([1025, 123, 2])

So in both cases center=True/False there is a difference of ± 4 frames in the STFT output shape … it would be very useful to have some clarifications on that please !

thanks

Hi, since you specified center=True, for each sampled time, a window will be placed so that it is centered around that time, and the input will be padded on both sides so that the first and last samples have sufficient data in the window. Hence you shouldn't subtract win_size when calculating n_frames.

See the doc here:

  • If center is True (default), input will be padded on both sides so that the t-th frame is centered at time t × hop_length. Otherwise, the t-th frame begins at time t × hop_length.

This behavior is also consistent with librosa.
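
As a minimal sketch of the frame count this implies: assuming the padding is n_fft // 2 samples on each side and that each frame spans n_fft samples (an assumption that matches the shapes reported above, not something stated in the quoted doc excerpt), the count can be computed without calling torch at all. The function name is hypothetical.

```python
def stft_num_frames_centered(data_len, n_fft, hop_size):
    """Expected number of STFT frames for center=True.

    Assumptions (hypothetical helper, not a torch API):
    - the input is padded by n_fft // 2 samples on both sides,
    - each frame spans n_fft samples of the padded input.
    """
    padded_len = data_len + 2 * (n_fft // 2)
    return 1 + (padded_len - n_fft) // hop_size


# Values from this thread: data_len = 33280, n_fft = 2048, hop_size = 256
print(stft_num_frames_centered(33280, 2048, 256))  # 131
```

For even n_fft this reduces to 1 + data_len // hop_size, so under these assumptions win_size does not enter the count.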

Thank you for the explanations, in my case I would then rather use center=False.
Given that and the documentation, I am still confused. Could you please clarify what is still wrong here? Sorry if I go into too much detail, but I need to relate frames <--> input signal length precisely.

(with sr=22050 ; win_size=1024 ; hop_size=256 ; n_fft=2048)
data_len = 250000
dummy_in = torch.ones(data_len,dtype=torch.float32).to(device)

STFT center=False means each frame starts at (frame_id*hop_length) and ends at (frame_id*hop_length+win_size-1) while signal goes from 0 to data_len-1

n_frames = int(np.floor((data_len-win_size)/hop_size))+1
crop_len = (n_frames-1)*hop_size+win_size
if crop_len+hop_size<=data_len:
    print('!!! error, missing signal windows !!!')

dummy_in = dummy_in[:crop_len]
data_mag = torch.stft(dummy_in,n_fft,hop_length=hop_size,win_length=win_size,window=window,center=False,pad_mode='reflect',normalized=False,onesided=True)

data_mag = torch.sqrt(torch.pow(data_mag[:,:,0],2)+torch.pow(data_mag[:,:,1],2))
if data_mag.shape[1]!=n_frames:
    print('wrong number of STFT frames, expected '+str(n_frames)+' and got '+str(data_mag.shape[1]))

the first check is ok, the signal is cropped to the maximum number of hops allowed without padding == 249856

the second is not: wrong number of STFT frames, expected 973 and got 969

so I still have an error of 4 frames and I don’t understand what is not correct in my calculation given the center=False

it would be extremely helpful to get the exact way to relate frames and signal length

thank you in advance !

Okay, the calculation was correct, but I found out that the behavior was not as expected because I was using twice as many FFT bins as the window size.

with n_fft = win_size the above formula seems correct
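
A minimal sketch of the count for center=False, assuming the slicing uses frames of length n_fft rather than win_size (an assumption that reproduces the shapes reported in this thread; the function name is hypothetical):

```python
def stft_num_frames_uncentered(data_len, frame_len, hop_size):
    """Expected number of frames for center=False (no padding).

    frame_len is the length of each slice taken from the signal.
    Assumption: torch 0.4.1 behaves as if frame_len == n_fft.
    """
    if data_len < frame_len:
        return 0  # not even one complete frame fits
    return 1 + (data_len - frame_len) // hop_size


crop_len = 249856  # cropped length for data_len = 250000
print(stft_num_frames_uncentered(crop_len, 2048, 256))  # 969 (frame_len = n_fft)
print(stft_num_frames_uncentered(crop_len, 1024, 256))  # 973 (frame_len = win_size)
```

The gap of 4 frames is (n_fft - win_size) / hop_size = 1024 / 256, which is why it vanishes when n_fft = win_size.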

I really appreciate what Pytorch team is doing and putting together STFT and related DSP tools in the framework is great !

so I would like to point this out in case it may help the further developments (or not, I might be wrong in what I underline)

the fact that n_fft affects the number of output frames is maybe "undesired"
as I learned it, the signal slicing is defined by the window size and the hop size (and an optional center argument)
using n_fft > win_size serves to increase the frequency resolution by taking each signal slice and zero-padding it before computing its FFT
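
To make the convention I have in mind concrete, here is a rough pure-Python sketch (a hypothetical helper, not how torch.stft is implemented): the signal is sliced by win_size and hop_size only, and each slice is then zero-padded to n_fft. The window multiplication and the FFT itself are left out, and the zeros are appended at the end for simplicity.

```python
def frame_and_zero_pad(signal, win_size, hop_size, n_fft):
    """Slice by win_size/hop_size, then zero-pad each slice to n_fft.

    Hypothetical illustration: the number of frames depends only on
    win_size and hop_size, never on n_fft.
    """
    if n_fft < win_size:
        raise ValueError("n_fft must be >= win_size")
    if len(signal) < win_size:
        return []
    n_frames = 1 + (len(signal) - win_size) // hop_size
    frames = []
    for i in range(n_frames):
        start = i * hop_size
        chunk = list(signal[start:start + win_size])
        # Pad with zeros up to the FFT length.
        frames.append(chunk + [0.0] * (n_fft - win_size))
    return frames


frames = frame_and_zero_pad([1.0] * 33280, 1024, 256, 2048)
print(len(frames), len(frames[0]))  # 127 2048
```

Each padded frame would then go through an FFT of length n_fft, giving n_fft // 2 + 1 = 1025 one-sided bins while the frame count stays at 127.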

maybe there is something I’m missing but I was actually confused because of this …
issue solved for n_fft = win_size though : )

Thank you for the input. :)

I remember looking into the n_fft > window_size case pretty deeply when I wrote stft. Oversampling is quite a corner case, so the behavior can indeed be a bit unexpected. Unfortunately I can't remember the exact reasons I chose to do it this way, but I vaguely recall that it aligns our stft behavior with librosa's.