class InitialReconstruction(BaseModule):
def init(self, n_fft, hop_size):
super(InitialReconstruction, self).init()
self.n_fft = n_fft
self.hop_size = hop_size
window = torch.hann_window(n_fft).float()
self.register_buffer(“window”, window)
def forward(self, stftm):
angle = torch.angle(stftm)
magnitudes = torch.abs(stftm)
stft = magnitudes * torch.exp(1j * angle)
istft = torch.istft(stft, n_fft=self.n_fft,
hop_length=self.hop_size, win_length=self.n_fft,
window=self.window, center=True,return_complex=False,onesided=False)
return istft.unsqueeze(1)
class FastGL(BaseModule):
def init(self, n_mels, sampling_rate, n_fft, hop_size, momentum=0.99):
super(FastGL, self).init()
self.n_mels = n_mels
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.hop_size = hop_size
self.momentum = momentum
self.pi = PseudoInversion(n_mels, sampling_rate, n_fft)
self.ir = InitialReconstruction(n_fft, hop_size)
window = torch.hann_window(n_fft).float()
self.register_buffer(“window”, window)
@torch.no_grad()
def forward(self, s, n_iters=32):
c = self.pi(s)
x = self.ir(c)
x = x.squeeze(1)
c = c.unsqueeze(-1)
prev_angles = torch.angle(s).clone().detach()
for _ in range(n_iters):
s_complex = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_size,
win_length=self.n_fft, window=self.window,
center=True, return_complex=False, onesided=True)
magnitudes = torch.abs(s_complex)
angles = torch.angle(s_complex)
prev_angles = angles
s = c * (magnitudes*torch.exp(1j*(angles + self.momentum * (angles - prev_angles))))
x = torch.istft(s, n_fft=self.n_fft, hop_length=self.hop_size,
win_length=self.n_fft, window=self.window,
center=True,return_complex=False, onesided=False)
return x.unsqueeze(1)
from these snippets, errors like “RuntimeError: Cannot have onesided output if window or input is complex” or “RuntimeError: istft(CUDAComplexFloatType[1, 513, 602], n_fft=1024, hop_length=256, win_length=1024, window=torch.cuda.FloatTensor{[1024]}, center=1, normalized=0, onesided=0, length=None, return_complex=1) : expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onesided=False, but got 513(when istft return complex = True)” still keeps occuring. what is the reason?