I trained a CNN model (one conv layer with maxpool(2,2) and one dense layer) on MNIST using a TensorFlow Estimator, with the layers defined by tf.keras.layers. The trained Estimator is called mnist_classifier.
The structure of mnist_classifier is defined as:
def cnn_model_fn(features, labels, mode):
    """Model function for a CNN."""
    input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
    y = tf.keras.layers.Conv2D(20, 5, strides=1, padding='valid', activation='relu').apply(input_layer)
    y = tf.keras.layers.MaxPool2D(2, 2).apply(y)
    y = tf.keras.layers.Flatten().apply(y)
    logits = tf.keras.layers.Dense(10).apply(y)
I extracted the weights of each layer as follows:
conv2d_kernel = mnist_classifier.get_variable_value('conv2d/kernel')
conv2d_bias = mnist_classifier.get_variable_value('conv2d/bias')
dense_kernel = mnist_classifier.get_variable_value('dense/kernel')
dense_bias = mnist_classifier.get_variable_value('dense/bias')
I then transposed conv2d_kernel and dense_kernel as:
conv2d_kernel_reshape = conv2d_kernel.transpose(3, 2, 0, 1)
dense_kernel_reshape = dense_kernel.transpose()
and used them to initialize a PyTorch model with the same structure as mnist_classifier:
class Net(nn.Module):
    """ define the same CNN model as mnist_classifier """
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv1.weight.data = torch.Tensor(conv2d_kernel_reshape)
        self.conv1.bias.data = torch.Tensor(conv2d_bias)
        self.fc1 = nn.Linear(12 * 12 * 20, 10)
        self.fc1.weight.data = torch.Tensor(dense_kernel_reshape)
        self.fc1.bias.data = torch.Tensor(dense_bias)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 12 * 12 * 20)
        x = self.fc1(x)
        return x
So the PyTorch model's weights are initialized with the weights extracted from the trained Estimator. However, when I test this PyTorch model on the MNIST test set, the test accuracy is around 10%, while the trained mnist_classifier has a test accuracy above 96%. I am sure that the weights are transferred correctly, so I am very confused about why the PyTorch model with the pre-trained weights does not work. Does anyone know what the problem is here? Thanks a lot in advance.
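One thing that may be worth checking is the flatten order. TensorFlow keeps activations in NHWC layout, so Flatten emits features in (height, width, channel) order, whereas the view call in the PyTorch model flattens an NCHW tensor in (channel, height, width) order. If that is the cause, a plain transpose of dense_kernel would pair each dense weight with the wrong feature. The NumPy-only sketch below illustrates this with random stand-ins rather than the trained weights; the names act_nhwc, naive_weight and fixed_weight are hypothetical and exist only for this check.

```python
import numpy as np

# Assumed shapes from the model above: 12x12 feature map, 20 channels, 10 classes.
H, W, C, OUT = 12, 12, 20, 10
rng = np.random.default_rng(0)

# Random stand-ins for one pooled activation (TensorFlow NHWC) and dense/kernel.
act_nhwc = rng.standard_normal((1, H, W, C))
dense_kernel = rng.standard_normal((H * W * C, OUT))

# What TensorFlow computes: flatten in (H, W, C) order, then multiply.
tf_logits = act_nhwc.reshape(1, -1) @ dense_kernel

# What PyTorch sees: the same activation as NCHW, flattened in (C, H, W) order.
flat_nchw = act_nhwc.transpose(0, 3, 1, 2).reshape(1, -1)

# Plain transpose: the input axis of the weight is still in (H, W, C) order.
naive_weight = dense_kernel.transpose()
naive_logits = flat_nchw @ naive_weight.T

# Reorder the input axis to (C, H, W) before flattening it to (OUT, C*H*W).
fixed_weight = (
    dense_kernel.reshape(H, W, C, OUT).transpose(3, 2, 0, 1).reshape(OUT, -1)
)
fixed_logits = flat_nchw @ fixed_weight.T

naive_matches = bool(np.allclose(tf_logits, naive_logits))
fixed_matches = bool(np.allclose(tf_logits, fixed_logits))
print("plain transpose matches TensorFlow:", naive_matches)
print("reordered kernel matches TensorFlow:", fixed_matches)
```

If the same pattern holds for the real weights, building dense_kernel_reshape the way fixed_weight is built here (instead of with a plain transpose) would be the change to try. The conv2d_kernel transpose (3, 2, 0, 1) already maps TensorFlow's (H, W, in, out) kernel layout to PyTorch's (out, in, H, W).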