Permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions

Since updating to PyTorch 2.0.0, I had to add return_complex=True to my code, but now it causes this error:
“RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 4”

Does anyone know which code I need to change and what to change it to (without downgrading to an older version of PyTorch)?

Code:

import torch

dim_s = 4

class STFT:
    def __init__(self, n_fft, hop_length, dim_f):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.window = torch.hann_window(window_length=n_fft, periodic=True)
        self.dim_f = dim_f

    def __call__(self, x):
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]
        x = x.reshape([-1, t])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=True)
        x = x.permute([0,3,1,2])
        x = x.reshape([*batch_dims,c,2,-1,x.shape[-1]]).reshape([*batch_dims,c*2,-1,x.shape[-1]])
        return x[...,:self.dim_f,:]

    def inverse(self, x):
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c,f,t = x.shape[-3:]
        n = self.n_fft//2+1
        f_pad = torch.zeros([*batch_dims,c,n-f,t]).to(x.device)
        x = torch.cat([x, f_pad], -2)
        x = x.reshape([*batch_dims,c//2,2,n,t]).reshape([-1,2,n,t])
        x = x.permute([0,2,3,1])
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=True)
        x = x.reshape([*batch_dims,2,-1])
        return x

I can reproduce the mentioned error using:

module = STFT(10, 10, 10)
x = torch.randn(1, 10, 10)
out = module(x)  # raises the RuntimeError quoted above
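
For context: with return_complex=True, torch.stft returns a complex tensor without the trailing real/imag dimension, so x is 3-dimensional while permute([0,3,1,2]) expects 4 dimensions. A quick standalone check (just a minimal sketch using the sizes from the repro above) shows the shape difference:

import torch

x = torch.randn(10, 10)  # flattened (batch * channels, time), as inside __call__
spec = torch.stft(x, n_fft=10, hop_length=10,
                  window=torch.hann_window(10, periodic=True),
                  center=True, return_complex=True)
print(spec.shape, spec.dtype)          # torch.Size([10, 6, 2]) torch.complex64 -> only 3 dims
print(torch.view_as_real(spec).shape)  # torch.Size([10, 6, 2, 2]) -> real/imag restored as a 4th dim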

Based on the docs, it seems you could restore the real-valued 4-dimensional tensor via torch.view_as_real, which also seems to fix your error:

    def __call__(self, x):
        window = self.window.to(x.device)
        batch_dims = x.shape[:-2]
        c, t = x.shape[-2:]
        x = x.reshape([-1, t])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=True)
        print(x.shape)
        x = torch.view_as_real(x)
        print(x.shape)
        x = x.permute([0,3,1,2])
        x = x.reshape([*batch_dims,c,2,-1,x.shape[-1]]).reshape([*batch_dims,c*2,-1,x.shape[-1]])
        return x[...,:self.dim_f,:]
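
With this change the repro runs through; checking the output shape for the toy input above (again, just a sketch):

module = STFT(10, 10, 10)
x = torch.randn(1, 10, 10)
out = module(x)
print(out.shape)  # torch.Size([1, 20, 6, 2])

The inverse path will probably need the mirrored change, since recent PyTorch versions expect a complex-valued input to torch.istft. Something along these lines might work (an untested assumption on my side; I also dropped return_complex from the istft call assuming your time-domain signal is real):

    def inverse(self, x):
        window = self.window.to(x.device)
        batch_dims = x.shape[:-3]
        c, f, t = x.shape[-3:]
        n = self.n_fft//2+1
        f_pad = torch.zeros([*batch_dims, c, n-f, t]).to(x.device)
        x = torch.cat([x, f_pad], -2)
        x = x.reshape([*batch_dims, c//2, 2, n, t]).reshape([-1, 2, n, t])
        x = x.permute([0, 2, 3, 1]).contiguous()
        x = torch.view_as_complex(x)  # rebuild the complex spectrogram before istft
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
        x = x.reshape([*batch_dims, 2, -1])
        return x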