Convert CNN from Keras to Pytorch

Can anyone help me convert this Keras implementation to Pytorch? For an input image size of 256x256

n_filters = 24 

def CNN(n_outputs=1):
  Conv2D = functools.partial(tf.keras.layers.Conv2D, padding='same', activation='relu')
  BatchNormalization = tf.keras.layers.BatchNormalization
  Flatten = tf.keras.layers.Flatten
  Dense = functools.partial(tf.keras.layers.Dense, activation='relu')

  model = tf.keras.Sequential([
    Conv2D(filters=1*n_filters, kernel_size=5,  strides=2),
    Conv2D(filters=2*n_filters, kernel_size=5,  strides=2),

    Conv2D(filters=4*n_filters, kernel_size=3,  strides=2),

    Conv2D(filters=6*n_filters, kernel_size=3,  strides=2),

    Dense(n_outputs, activation=None),
  return model

This was my attempt:

c = 24
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv0 = nn.Conv2d(in_channels=3, out_channels=c, kernel_size=5, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=2*c, kernel_size=5, stride=2, padding=1) # out: c x 14 x 14
        self.conv2 = nn.Conv2d(in_channels=2*c, out_channels=4*c, kernel_size=3, stride=2, padding=1) # out: c x 7 x 7
        self.conv3 = nn.Conv2d(in_channels=4*c, out_channels=6*c, kernel_size=3, stride=2, padding=1) # out: c x 7 x 7
        self.fc_mu = nn.Linear(in_features=36864, out_features=1)
    def forward(self, x):
        x = F.relu(self.conv0(x))
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), 144*16*16) 
        x = self.fc_mu(x)
       return x

I think what I don’t understand here is how to write the Flatten and Linear layer properly, for that 512 input.

Thank you

Here, there should be 2 Linear layers:

self.linear1 = nn.Linear(in_features=36864, out_features=512)
self.linear2 = nn.Linear(in_features=512, out_features=1)


x = x.view(x.size(0), 144*16*16) # Flatten(),
x = F.relu(self.linear1(x)) #Dense(512),
x = self.linear2(x) #Dense(n_outputs, activation=None)

Another thing is adding BatchNorm:

self.batch_norm_0 = nn.BatchNorm2d(c)

And in forward:

x = F.relu(self.conv0(x))
x = batch_norm_0(x)