Difference in batchnorm outputs when converting from TF model to Pytorch

I am trying to port a TF inceptionv3 model to pytorch.

Issue: I ran both models in inference mode. While the output of the first conv layer is comparable between the 2 models, the output of the succeeding batch norm layer is not comparable.

Where might the issue be?

Below is a minimal example.

Model file reference in code is available from here: https://figshare.com/articles/Trained_neural_network_models/8312183

Related posts: Convering a batch normalization layer from TF to Pytorch

import tensorflow as tf
from tensorflow.compat import v1 as v1

from collections import OrderedDict

import os
import sys
import pickle

import numpy as np

import torch
from torch import nn
from torchvision import models
from torch.utils.tensorboard import SummaryWriter

import argparse


tf.compat.v1.disable_eager_execution()

dtype='float'
random_input=np.random.rand(1,3,299,299)

# extract weights from TF graph
sess=v1.Session()
v1.keras.backend.set_session(sess)
v1.keras.backend.set_learning_phase(False)

input_mode='channels_first'
v1.keras.backend.set_image_data_format(input_mode)

tf_input_shape=(1,3,299,299)


saver = v1.train.import_meta_graph("model-1.meta")
saver.restore(sess,"model-1")

graph=v1.get_default_graph()

v1.summary.FileWriter(os.path.join("__tb"), sess.graph)


# extract weights from TF graph
true_vars=OrderedDict()
rms_prop_0_vars=OrderedDict()
rms_prop_1_vars=OrderedDict()
for var in v1.global_variables(): # includes the BN
	# print(var.name)
	var_name_split=var.name.split("/")
	if 'RMSProp:0' in var_name_split[-1]:
		rms_prop_0_vars[var.name]=sess.run(var)
	elif 'RMSProp_1:0' in var_name_split[-1]:
		rms_prop_1_vars[var.name]=sess.run(var)
	else:
		true_vars[var.name]=sess.run(var)



ops=graph.get_operations()
op_names=[op.name for op in ops]
op_output_tensors=[op.values() for op in ops]
# https://github.com/tensorflow/tensorflow/issues/33129
output_tensors_by_name=['conv2d/Conv2D:0',
						'batch_normalization/FusedBatchNorm:0',
						'activation/Relu:0']

output_tensors=list(map(lambda t_name: graph.get_tensor_by_name(t_name),
					output_tensors_by_name))

input_var=random_input
tf_outputs=sess.run(output_tensors, feed_dict={'x:0':input_var})



class NetHelper(nn.Module):
	def __init__(self,num_in_channels,num_out_channels,kernel_size,stride):
		super(NetHelper, self).__init__()
		self.conv = nn.Conv2d(num_in_channels,num_out_channels,kernel_size=kernel_size,
								stride=stride,padding=0,bias=False)
		self.bn = nn.BatchNorm2d(num_out_channels,eps=0.001)

	def forward(self,x):
		return nn.functional.relu(self.bn(self.conv(x)),inplace=True)

class Net(nn.Module):

	def __init__(self):
		super(Net, self).__init__()
		# 1 input image channel, 6 output channels, 3x3 square convolution
		# kernel
		self.Conv2d_1a_3x3=NetHelper(3,32,(3,3),stride=2)
		

	def forward(self, x):
		# Max pooling over a (2, 2) window
		
		x = self.Conv2d_1a_3x3(x)

		return x



model=Net()

kernel=true_vars['conv2d/kernel:0']
kernel_transposed=np.transpose(kernel,(3,2,0,1))
model.Conv2d_1a_3x3.conv.weight=torch.nn.Parameter(torch.from_numpy(kernel_transposed),requires_grad=False)
bn_beta=true_vars['batch_normalization/beta:0']
model.Conv2d_1a_3x3.bn.bias=torch.nn.Parameter(torch.from_numpy(bn_beta),requires_grad=False)
bn_mean=true_vars['batch_normalization/moving_mean:0']
model.Conv2d_1a_3x3.bn.running_mean=torch.nn.Parameter(torch.from_numpy(bn_mean),requires_grad=False)
bn_var=true_vars['batch_normalization/moving_variance:0']
model.Conv2d_1a_3x3.bn.running_var=torch.nn.Parameter(torch.from_numpy(bn_var),requires_grad=False)


model=model.float().eval()


def get_activation(name,activation_dict):
	def hook(model, input, output):
		activation_dict[name] = output.detach()
	return hook


conv_bn_layer=model.Conv2d_1a_3x3
conv_bn_layer_activations={}
layer_names=['conv','bn']
for layer_name in layer_names:
	getattr(conv_bn_layer,layer_name).register_forward_hook(get_activation(layer_name,conv_bn_layer_activations))

with torch.no_grad():
	output=conv_bn_layer(getattr(torch.from_numpy(random_input),dtype)())
	conv_bn_layer_activations['activation']=output

diff=abs(tf_outputs[0]-conv_bn_layer_activations['conv'].numpy())
print(np.amin(diff),np.amax(diff))
diff=abs(tf_outputs[1]-conv_bn_layer_activations['bn'].numpy())
print(np.amin(diff),np.amax(diff))
diff=abs(tf_outputs[2]-conv_bn_layer_activations['activation'].numpy())
print(np.amin(diff),np.amax(diff))

print(np.allclose(tf_outputs[0],conv_bn_layer_activations['conv']))
print(np.allclose(tf_outputs[1],conv_bn_layer_activations['bn']))
print(np.allclose(tf_outputs[2],conv_bn_layer_activations['activation']))
1 Like

How large are the absolute errors between the different runs?