I am trying to port a TF inceptionv3 model to pytorch.
Issue: I ran both models in inference mode. While the output of the first conv layer is comparable between the 2 models, the output of the succeeding batch norm layer is not comparable.
Where might the issue be?
Below is a minimal example.
Model file reference in code is available from here: https://figshare.com/articles/Trained_neural_network_models/8312183
Related posts: Convering a batch normalization layer from TF to Pytorch
import tensorflow as tf
from tensorflow.compat import v1 as v1
from collections import OrderedDict
import os
import sys
import pickle
import numpy as np
import torch
from torch import nn
from torchvision import models
from torch.utils.tensorboard import SummaryWriter
import argparse
tf.compat.v1.disable_eager_execution()
dtype='float'
random_input=np.random.rand(1,3,299,299)
# extract weights from TF graph
sess=v1.Session()
v1.keras.backend.set_session(sess)
v1.keras.backend.set_learning_phase(False)
input_mode='channels_first'
v1.keras.backend.set_image_data_format(input_mode)
tf_input_shape=(1,3,299,299)
saver = v1.train.import_meta_graph("model-1.meta")
saver.restore(sess,"model-1")
graph=v1.get_default_graph()
v1.summary.FileWriter(os.path.join("__tb"), sess.graph)
# extract weights from TF graph
true_vars=OrderedDict()
rms_prop_0_vars=OrderedDict()
rms_prop_1_vars=OrderedDict()
for var in v1.global_variables(): # includes the BN
# print(var.name)
var_name_split=var.name.split("/")
if 'RMSProp:0' in var_name_split[-1]:
rms_prop_0_vars[var.name]=sess.run(var)
elif 'RMSProp_1:0' in var_name_split[-1]:
rms_prop_1_vars[var.name]=sess.run(var)
else:
true_vars[var.name]=sess.run(var)
ops=graph.get_operations()
op_names=[op.name for op in ops]
op_output_tensors=[op.values() for op in ops]
# https://github.com/tensorflow/tensorflow/issues/33129
output_tensors_by_name=['conv2d/Conv2D:0',
'batch_normalization/FusedBatchNorm:0',
'activation/Relu:0']
output_tensors=list(map(lambda t_name: graph.get_tensor_by_name(t_name),
output_tensors_by_name))
input_var=random_input
tf_outputs=sess.run(output_tensors, feed_dict={'x:0':input_var})
class NetHelper(nn.Module):
def __init__(self,num_in_channels,num_out_channels,kernel_size,stride):
super(NetHelper, self).__init__()
self.conv = nn.Conv2d(num_in_channels,num_out_channels,kernel_size=kernel_size,
stride=stride,padding=0,bias=False)
self.bn = nn.BatchNorm2d(num_out_channels,eps=0.001)
def forward(self,x):
return nn.functional.relu(self.bn(self.conv(x)),inplace=True)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 6 output channels, 3x3 square convolution
# kernel
self.Conv2d_1a_3x3=NetHelper(3,32,(3,3),stride=2)
def forward(self, x):
# Max pooling over a (2, 2) window
x = self.Conv2d_1a_3x3(x)
return x
model=Net()
kernel=true_vars['conv2d/kernel:0']
kernel_transposed=np.transpose(kernel,(3,2,0,1))
model.Conv2d_1a_3x3.conv.weight=torch.nn.Parameter(torch.from_numpy(kernel_transposed),requires_grad=False)
bn_beta=true_vars['batch_normalization/beta:0']
model.Conv2d_1a_3x3.bn.bias=torch.nn.Parameter(torch.from_numpy(bn_beta),requires_grad=False)
bn_mean=true_vars['batch_normalization/moving_mean:0']
model.Conv2d_1a_3x3.bn.running_mean=torch.nn.Parameter(torch.from_numpy(bn_mean),requires_grad=False)
bn_var=true_vars['batch_normalization/moving_variance:0']
model.Conv2d_1a_3x3.bn.running_var=torch.nn.Parameter(torch.from_numpy(bn_var),requires_grad=False)
model=model.float().eval()
def get_activation(name,activation_dict):
def hook(model, input, output):
activation_dict[name] = output.detach()
return hook
conv_bn_layer=model.Conv2d_1a_3x3
conv_bn_layer_activations={}
layer_names=['conv','bn']
for layer_name in layer_names:
getattr(conv_bn_layer,layer_name).register_forward_hook(get_activation(layer_name,conv_bn_layer_activations))
with torch.no_grad():
output=conv_bn_layer(getattr(torch.from_numpy(random_input),dtype)())
conv_bn_layer_activations['activation']=output
diff=abs(tf_outputs[0]-conv_bn_layer_activations['conv'].numpy())
print(np.amin(diff),np.amax(diff))
diff=abs(tf_outputs[1]-conv_bn_layer_activations['bn'].numpy())
print(np.amin(diff),np.amax(diff))
diff=abs(tf_outputs[2]-conv_bn_layer_activations['activation'].numpy())
print(np.amin(diff),np.amax(diff))
print(np.allclose(tf_outputs[0],conv_bn_layer_activations['conv']))
print(np.allclose(tf_outputs[1],conv_bn_layer_activations['bn']))
print(np.allclose(tf_outputs[2],conv_bn_layer_activations['activation']))