Since updating to PyTorch 2.0.0, I had to add `return_complex=True` to my code, but now it causes this error:

“RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 4”

Does anyone know which code I need to change, and what to change it to, without downgrading to an older version of PyTorch?

Code:

# NOTE(review): not referenced in this snippet; presumably the spectrogram
# channel count (2 audio channels x real/imag) used elsewhere -- confirm.
dim_s = 4


class STFT:
    """Short-time Fourier transform that packs real/imag parts into channels.

    ``__call__`` maps a real waveform of shape ``(*batch, c, time)`` to a real
    spectrogram of shape ``(*batch, 2 * c, dim_f, frames)``, where each input
    channel contributes an adjacent (real, imag) pair of output channels.
    ``inverse`` undoes that packing and runs the inverse STFT.

    Since PyTorch 2.0, ``torch.stft`` on real input must be called with an
    explicit ``return_complex``; with ``return_complex=True`` it returns a
    complex tensor *without* the trailing size-2 real/imag axis, so the code
    converts between the complex and real layouts with
    ``torch.view_as_real`` / ``torch.view_as_complex``.
    """

    def __init__(self, n_fft, hop_length, dim_f):
        """
        Args:
            n_fft: FFT size; the one-sided spectrum has ``n_fft // 2 + 1`` bins.
            hop_length: Hop between successive frames, in samples.
            dim_f: Number of low-frequency bins kept by ``__call__``.
        """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.window = torch.hann_window(window_length=n_fft, periodic=True)
        self.dim_f = dim_f

    def __call__(self, x):
        """Return the packed real spectrogram of ``x``.

        Args:
            x: Real tensor of shape ``(*batch, c, time)``.

        Returns:
            Real tensor of shape ``(*batch, 2 * c, min(dim_f, n_fft//2+1), frames)``.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]
        x = x.reshape([-1, t])
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=True,
        )
        # Complex (N, bins, frames) -> real (N, bins, frames, 2); this restores
        # the 4-D layout the permute below expects (the PyTorch 2.0 error).
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])  # (N, 2, bins, frames)
        frames = x.shape[-1]
        x = x.reshape([*batch_dims, c, 2, -1, frames]).reshape(
            [*batch_dims, c * 2, -1, frames]
        )
        return x[..., : self.dim_f, :]

    def inverse(self, x):
        """Invert ``__call__``: packed spectrogram -> real waveform.

        Bins cropped away by ``dim_f`` are restored as zeros before the
        inverse transform.

        Args:
            x: Real tensor of shape ``(*batch, c, f, frames)`` with ``c`` even
                (adjacent real/imag channel pairs) and ``f <= n_fft // 2 + 1``.

        Returns:
            Real tensor of shape ``(*batch, c // 2, time)``.
        """
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c, f, t = x.shape[-3:]
        n = self.n_fft // 2 + 1
        f_pad = torch.zeros([*batch_dims, c, n - f, t], device=x.device)
        x = torch.cat([x, f_pad], -2)
        x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
        x = x.permute([0, 2, 3, 1])  # (N, bins, frames, 2)
        # view_as_complex needs the size-2 last dim to have stride 1, which the
        # permute above breaks, hence .contiguous().
        x = torch.view_as_complex(x.contiguous())
        # Output is real (one-sided spectrum), so return_complex must not be True.
        x = torch.istft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
        )
        # One waveform per (real, imag) channel pair; previously hard-coded to 2.
        return x.reshape([*batch_dims, c // 2, -1])