Having permute error when converting model from Keras to PyTorch

I am trying to reproduce a model written in Keras in a PyTorch framework but I cannot resolve an error when I use permute to achieve the same effect as layers.transpose in Keras.
Here’s the original code:

class SpectralTransform(layers.Layer):

def __init__(self, c, size=3):
    super().__init__()
    self.conv1 = keras.Sequential([
        layers.Conv1D(c*2, size, 1, padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu'),
    ])

def call(self, y):
    x = tf.transpose(y, (0, 2, 1))
    x = tf.signal.rfft(x)
    r, i = tf.math.real(x), tf.math.imag(x)
    x = tf.concat([r, i], axis=-2)
    x = tf.transpose(x, (0, 2, 1))
    x = self.conv1(x)
    x = tf.transpose(x, (0, 2, 1))
    r, i = tf.split(x, 2, axis=-2)
    x = tf.complex(r, i)
    x = tf.signal.irfft(x)
    x = tf.transpose(x, (0, 2, 1))
    return layers.add([x,y])

class M_1(keras.Model):

def __init__(self):
    super().__init__()
    self.a = keras.Sequential([
        layers.Reshape((1350, -1)), 
        layers.Conv1D(64, 3, 3),
    ])
    self.ST1 = SpectralTransform(64, 5)
    self.conv1_1 = layers.Conv1D(64, 10, 1, padding='same')
    self.atv1 = keras.Sequential([layers.BatchNormalization(), layers.Activation('relu')])
    self.ST2 = SpectralTransform(64, 3)
    self.conv2 = layers.Conv1D(32, 5, 1, padding='same')
    self.atv2 = keras.Sequential([layers.BatchNormalization(), layers.Activation('relu')])
    self.z = keras.Sequential([
        layers.Conv1D(1, 1, 1),
        layers.Reshape((-1,)),
    ])
    self.w = {}

def call(self, x):
    x = (x - tf.reshape(tf.reduce_mean(x, axis=(1, 2, 3)), (-1, 1, 1, 1, 3)))/tf.reshape(tf.math.reduce_std(x, axis=(1, 2, 3)), (-1, 1, 1, 1, 3))
    
    x = tf.transpose(x, (0, 1, 4, 2, 3))
    x = self.a(x)
    x = self.ST1(x)
    x = self.conv1_1(x)
    x = self.atv1(x)
    x = self.ST2(x)
    x = self.conv2(x)
    x = self.atv2(x)
    return self.z(x)

My PyTorch code is like this:

class SpectralTransform(nn.Module):

def __init__(self, c, size=3):
    super(SpectralTransform, self).__init__()
    self.conv1 = nn.Sequential(
        nn.Conv1d(in_channels=c, out_channels=c*2, kernel_size=size, stride=1, padding=size//2),
        nn.BatchNorm1d(c*2),
        nn.ReLU()
    )

def forward(self, y):
    x = torch.transpose(y, 1, 2)
    x = fft.rfft(x)
    r, i = torch.real(x), torch.imag(x)
    x = torch.cat([r, i], dim=-2)
    x = torch.transpose(x, 1, 2)
    x = self.conv1(x)
    x = torch.transpose(x, 1, 2)
    r, i = torch.split(x, x.size(-2) // 2, dim=-2)
    x = torch.complex(r, i)
    x = fft.irfft(x)
    x = torch.transpose(y, 1, 2)
    return x + y

class Seq_rPPG(nn.Module):

def __init__(self):
    super(Seq_rPPG, self).__init__()
    self.a = nn.Sequential(
        nn.Flatten(),
        nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=3)
    )
    self.ST1 = SpectralTransform(64, 5)
    self.conv1_1 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=10, stride=1, padding=10//2)
    self.atv1 = nn.Sequential(
        nn.BatchNorm1d(64), 
        nn.ReLU()
    )
    self.ST2 = SpectralTransform(64, 3)
    self.conv2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=5//2)
    self.atv2 = nn.Sequential(
        nn.BatchNorm1d(32), 
        nn.ReLU()
    )
    self.z = nn.Sequential(
        nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1, stride=1),
        nn.Flatten(start_dim=0, end_dim=-1)
    )
    self.w = {}

def forward(self, x):
    x = (x - torch.reshape(torch.mean(x, dim=(1, 2, 3)), (-1, 1, 1, 1))) / torch.reshape(torch.std(x, dim=(1, 2, 3)), (-1, 1, 1, 1))
    x = x.permute(0, 1, 4, 2, 3)
    x = self.a(x)
    x = self.ST1(x)
    x = self.conv1_1(x)
    x = self.atv1(x)
    x = self.ST2(x)
    x = self.conv2(x)
    x = self.atv2(x)
    return self.z(x)

To be more specific, I got the error at the “x = x.permute(0, 1, 4, 2, 3)” line in the forward method, the error message is “RuntimeError: number of dims don’t match in permute”. Could anyone please take a look? Any help is appreciated!