I have a tensorflow function. I want to convert it to pytorch. However, I got a big error between tensorflow and pytorch result. Could you please check my code and let me know what is the problem?
tensorflow result 0.23463342
Pytorch result 0.035501156
The colab is at Here
This is my target tensorflow code
eps = 1e-5
ndims = 3
win = [9] * ndims
####################################
# TENSORFLOW
####################################
def ncc_tensorflow(I, J):
# get convolution function
conv_fn = getattr(tf.nn, 'conv%dd' % ndims)
# compute CC squares
I2 = I*I
J2 = J*J
IJ = I*J
# compute filters
sum_filt = tf.ones([*win, 1, 1])
padding = 'SAME'
strides = [1] * (ndims + 2)
# compute local sums via convolution
I_sum = conv_fn(I, sum_filt, strides, padding)
J_sum = conv_fn(J, sum_filt, strides, padding)
I2_sum = conv_fn(I2, sum_filt, strides, padding)
J2_sum = conv_fn(J2, sum_filt, strides, padding)
IJ_sum = conv_fn(IJ, sum_filt, strides, padding)
# compute cross correlation
win_size = np.prod(win)
u_I = I_sum/win_size
u_J = J_sum/win_size
cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size
cc = cross*cross / (I_var*J_var + eps)
# return negative cc.
return tf.reduce_mean(cc)
And I reproduce it in pytorch by
####################################
# PYTORCH
####################################
def ncc_torch( I, J):
# compute CC squares
I2 = I*I
J2 = J*J
IJ = I*J
# compute filters
batch_size, channels, _, _, _ = I.shape
sum_filt = torch.ones((batch_size, channels, *win)).float()
strides = [1] * (ndims)
# compute local sums via convolution
I_sum = F.conv3d(I, sum_filt, stride= strides, padding=1)
J_sum = F.conv3d(J, sum_filt, stride= strides, padding=1)
I2_sum = F.conv3d(I2, sum_filt, stride=strides, padding=1)
J2_sum = F.conv3d(J2, sum_filt, stride=strides, padding=1)
IJ_sum = F.conv3d(IJ, sum_filt, stride=strides, padding=1)
# compute cross correlation
win_size = np.prod(win)
u_I = I_sum / win_size
u_J = J_sum / win_size
cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size
cc = cross*cross / (I_var*J_var + eps)
return torch.mean(cc)
The unit test is
# Unit test
I = torch.rand(1,18,18,18,1) #BDHWC
J = torch.rand(1,18,18,18,1)
with tf.Session() as sess:
tf_ncc = ncc_tensorflow(I, J)
tf_result = sess.run(tf_ncc)
print ('tensorflow result ', tf_result)
I = I.permute(0,4,1,2,3) #BCDHW
J = J.permute(0,4,1,2,3) #BCDHW
print('Pytorch result' , ncc_torch(I,J).numpy())
``