About stft and istft

class InitialReconstruction(BaseModule):
def init(self, n_fft, hop_size):
super(InitialReconstruction, self).init()
self.n_fft = n_fft
self.hop_size = hop_size
window = torch.hann_window(n_fft).float()
self.register_buffer(“window”, window)

def forward(self, stftm):
    angle = torch.angle(stftm)
    magnitudes = torch.abs(stftm)
    stft = magnitudes * torch.exp(1j * angle)
    istft = torch.istft(stft, n_fft=self.n_fft, 
                    hop_length=self.hop_size, win_length=self.n_fft, 
                    window=self.window, center=True,return_complex=False,onesided=False)

return istft.unsqueeze(1)

class FastGL(BaseModule):
def init(self, n_mels, sampling_rate, n_fft, hop_size, momentum=0.99):
super(FastGL, self).init()
self.n_mels = n_mels
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.hop_size = hop_size
self.momentum = momentum
self.pi = PseudoInversion(n_mels, sampling_rate, n_fft)
self.ir = InitialReconstruction(n_fft, hop_size)
window = torch.hann_window(n_fft).float()
self.register_buffer(“window”, window)

def forward(self, s, n_iters=32):
    c = self.pi(s)
    x = self.ir(c)
    x = x.squeeze(1)
    c = c.unsqueeze(-1)
    prev_angles = torch.angle(s).clone().detach()
    for _ in range(n_iters):        
        s_complex = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_size, 
                       win_length=self.n_fft, window=self.window, 
                       center=True, return_complex=False, onesided=True)
       magnitudes = torch.abs(s_complex)
        angles = torch.angle(s_complex)
        prev_angles = angles
       s = c * (magnitudes*torch.exp(1j*(angles + self.momentum * (angles - prev_angles))))
        x = torch.istft(s, n_fft=self.n_fft, hop_length=self.hop_size, 
                                       win_length=self.n_fft, window=self.window, 
                                       center=True,return_complex=False, onesided=False)

return x.unsqueeze(1)

from these snippets, errors like “RuntimeError: Cannot have onesided output if window or input is complex” or “RuntimeError: istft(CUDAComplexFloatType[1, 513, 602], n_fft=1024, hop_length=256, win_length=1024, window=torch.cuda.FloatTensor{[1024]}, center=1, normalized=0, onesided=0, length=None, return_complex=1) : expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onesided=False, but got 513(when istft return complex = True)” still keeps occuring. what is the reason?