I am trying to reproduce a model written in Keras in a PyTorch framework but I cannot resolve an error when I use permute to achieve the same effect as layers.transpose in Keras.
Here’s the original code:
class SpectralTransform(layers.Layer):
def __init__(self, c, size=3):
super().__init__()
self.conv1 = keras.Sequential([
layers.Conv1D(c*2, size, 1, padding='same'),
layers.BatchNormalization(),
layers.Activation('relu'),
])
def call(self, y):
x = tf.transpose(y, (0, 2, 1))
x = tf.signal.rfft(x)
r, i = tf.math.real(x), tf.math.imag(x)
x = tf.concat([r, i], axis=-2)
x = tf.transpose(x, (0, 2, 1))
x = self.conv1(x)
x = tf.transpose(x, (0, 2, 1))
r, i = tf.split(x, 2, axis=-2)
x = tf.complex(r, i)
x = tf.signal.irfft(x)
x = tf.transpose(x, (0, 2, 1))
return layers.add([x,y])
class M_1(keras.Model):
def __init__(self):
super().__init__()
self.a = keras.Sequential([
layers.Reshape((1350, -1)),
layers.Conv1D(64, 3, 3),
])
self.ST1 = SpectralTransform(64, 5)
self.conv1_1 = layers.Conv1D(64, 10, 1, padding='same')
self.atv1 = keras.Sequential([layers.BatchNormalization(), layers.Activation('relu')])
self.ST2 = SpectralTransform(64, 3)
self.conv2 = layers.Conv1D(32, 5, 1, padding='same')
self.atv2 = keras.Sequential([layers.BatchNormalization(), layers.Activation('relu')])
self.z = keras.Sequential([
layers.Conv1D(1, 1, 1),
layers.Reshape((-1,)),
])
self.w = {}
def call(self, x):
x = (x - tf.reshape(tf.reduce_mean(x, axis=(1, 2, 3)), (-1, 1, 1, 1, 3)))/tf.reshape(tf.math.reduce_std(x, axis=(1, 2, 3)), (-1, 1, 1, 1, 3))
x = tf.transpose(x, (0, 1, 4, 2, 3))
x = self.a(x)
x = self.ST1(x)
x = self.conv1_1(x)
x = self.atv1(x)
x = self.ST2(x)
x = self.conv2(x)
x = self.atv2(x)
return self.z(x)
My PyTorch code is like this:
class SpectralTransform(nn.Module):
def __init__(self, c, size=3):
super(SpectralTransform, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv1d(in_channels=c, out_channels=c*2, kernel_size=size, stride=1, padding=size//2),
nn.BatchNorm1d(c*2),
nn.ReLU()
)
def forward(self, y):
x = torch.transpose(y, 1, 2)
x = fft.rfft(x)
r, i = torch.real(x), torch.imag(x)
x = torch.cat([r, i], dim=-2)
x = torch.transpose(x, 1, 2)
x = self.conv1(x)
x = torch.transpose(x, 1, 2)
r, i = torch.split(x, x.size(-2) // 2, dim=-2)
x = torch.complex(r, i)
x = fft.irfft(x)
x = torch.transpose(y, 1, 2)
return x + y
class Seq_rPPG(nn.Module):
def __init__(self):
super(Seq_rPPG, self).__init__()
self.a = nn.Sequential(
nn.Flatten(),
nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=3)
)
self.ST1 = SpectralTransform(64, 5)
self.conv1_1 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=10, stride=1, padding=10//2)
self.atv1 = nn.Sequential(
nn.BatchNorm1d(64),
nn.ReLU()
)
self.ST2 = SpectralTransform(64, 3)
self.conv2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=5//2)
self.atv2 = nn.Sequential(
nn.BatchNorm1d(32),
nn.ReLU()
)
self.z = nn.Sequential(
nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1, stride=1),
nn.Flatten(start_dim=0, end_dim=-1)
)
self.w = {}
def forward(self, x):
x = (x - torch.reshape(torch.mean(x, dim=(1, 2, 3)), (-1, 1, 1, 1))) / torch.reshape(torch.std(x, dim=(1, 2, 3)), (-1, 1, 1, 1))
x = x.permute(0, 1, 4, 2, 3)
x = self.a(x)
x = self.ST1(x)
x = self.conv1_1(x)
x = self.atv1(x)
x = self.ST2(x)
x = self.conv2(x)
x = self.atv2(x)
return self.z(x)
To be more specific, I got the error at the “x = x.permute(0, 1, 4, 2, 3)” line in the forward method, the error message is “RuntimeError: number of dims don’t match in permute”. Could anyone please take a look? Any help is appreciated!