Large discrepancy between TensorFlow code and its PyTorch reproduction

I have a TensorFlow function that I want to convert to PyTorch. However, I get a large discrepancy between the TensorFlow and PyTorch results. Could you please check my code and let me know what the problem is?

tensorflow result  0.23463342
Pytorch result 0.035501156

A Colab notebook reproducing this is linked here.

This is my target TensorFlow code:

eps = 1e-5
ndims = 3
win = [9] * ndims
####################################
#   TENSORFLOW
####################################
def ncc_tensorflow(I, J):
  # get convolution function
  conv_fn = getattr(tf.nn, 'conv%dd' % ndims)
  # compute CC squares
  I2 = I*I
  J2 = J*J
  IJ = I*J
  # compute filters
  sum_filt = tf.ones([*win, 1, 1])
  padding = 'SAME'
  strides = [1] * (ndims + 2)
  # compute local sums via convolution
  I_sum = conv_fn(I, sum_filt, strides, padding)
  J_sum = conv_fn(J, sum_filt, strides, padding)
  I2_sum = conv_fn(I2, sum_filt, strides, padding)
  J2_sum = conv_fn(J2, sum_filt, strides, padding)
  IJ_sum = conv_fn(IJ, sum_filt, strides, padding)

  # compute cross correlation
  win_size = np.prod(win)
  u_I = I_sum/win_size
  u_J = J_sum/win_size

  cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
  I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
  J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size

  cc = cross*cross / (I_var*J_var + eps)

  # return negative cc.
  return tf.reduce_mean(cc)

And I reproduce it in PyTorch by:

####################################
#   PYTORCH
#################################### 
def ncc_torch(I, J):
  # compute CC squares
  I2 = I*I
  J2 = J*J
  IJ = I*J
  # compute filters
  batch_size, channels, _, _, _ = I.shape
  sum_filt = torch.ones((batch_size, channels, *win)).float()
  strides = [1] * (ndims)
  # compute local sums via convolution
  I_sum = F.conv3d(I, sum_filt, stride=strides, padding=1)
  J_sum = F.conv3d(J, sum_filt, stride=strides, padding=1)
  I2_sum = F.conv3d(I2, sum_filt, stride=strides, padding=1)
  J2_sum = F.conv3d(J2, sum_filt, stride=strides, padding=1)
  IJ_sum = F.conv3d(IJ, sum_filt, stride=strides, padding=1)

  # compute cross correlation
  win_size = np.prod(win)
  u_I = I_sum / win_size
  u_J = J_sum / win_size

  cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
  I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
  J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size

  cc = cross*cross / (I_var*J_var + eps)
  
  return torch.mean(cc)

The unit test is:

# Unit test
I = torch.rand(1,18,18,18,1) #BDHWC
J = torch.rand(1,18,18,18,1)
with tf.Session() as sess:
  tf_ncc = ncc_tensorflow(I, J)
  tf_result = sess.run(tf_ncc)
  print ('tensorflow result ', tf_result)

I = I.permute(0,4,1,2,3) #BCDHW
J = J.permute(0,4,1,2,3) #BCDHW
print('Pytorch result' , ncc_torch(I,J).numpy())
  


Skimming through your code it looks like the padding in your PyTorch approach might be wrong.
If I’m not mistaken you are using a 3D ([9x9x9]) conv kernel. To get the same volumetric output as your input, you should use padding=4 for stride=1.
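To illustrate that suggestion: for an odd kernel with stride 1, the PyTorch padding that matches TensorFlow's `'SAME'` is `kernel_size // 2`. A minimal shape check (standalone, not the thread's full NCC code):

```python
import torch
import torch.nn.functional as F

# For an odd kernel and stride 1, 'SAME'-style padding is kernel_size // 2.
kernel_size = 9
padding = kernel_size // 2  # 4

x = torch.rand(1, 1, 18, 18, 18)          # NCDHW input
w = torch.ones(1, 1, *[kernel_size] * 3)  # (out_ch=1, in_ch=1, 9, 9, 9)
y = F.conv3d(x, w, stride=1, padding=padding)
print(tuple(y.shape))  # (1, 1, 18, 18, 18) -- spatial size preserved
```

With `padding=1` (as in the original snippet) the output would shrink to 12x12x12, so the PyTorch mean is taken over a different region than the TensorFlow one.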


Great. The result looks closer than before, using input size 2x81x81x81x1 and your kernel-size suggestion:

tensorflow result 0.05861372 
Pytorch result 0.029243657

However, the code also uses

win=[kernel size] * 3
win_size=np.prod(win)

So if I use a kernel size of 4, it will change the result of np.prod(win). Do you think we should still keep the same win_size value for both TF and PyTorch, i.e. 9x9x9, and just modify the size in the conv3d?

I’m not sure I understand the last question properly, but I would suggest sticking as close as possible to the reference implementation in order to get the same results.
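To spell out the point of confusion: the kernel stays 9x9x9 in both frameworks, so win_size is unchanged at 729; it is only the conv3d *padding* argument that becomes 4. A quick sanity check:

```python
import numpy as np

win = [9] * 3               # kernel size: unchanged in both TF and PyTorch
win_size = int(np.prod(win))  # 9 * 9 * 9 = 729
padding = win[0] // 2         # only the padding differs: 4, not the kernel size
print(win_size, padding)      # 729 4
```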

The filter shape is wrong; it should have a first dimension of size 1.

@SimonW: Thanks. Could you show me where it’s wrong? I can’t spot it.

I think @SimonW means that the snippet above should be

sum_filt = torch.ones((1, channels, *win)).float()

Thanks. That was my typo. I fixed it, but the result is the same because my example batch size is 1. I guess the main problem is the padding type.


@ptrblck it worked. Sorry, I misread your comment. The padding should be 4, not the kernel size.
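For later readers, here is a sketch of the PyTorch function with both fixes from this thread applied (filter shape `(1, channels, *win)` and `padding=4`); the name `ncc_torch_fixed` is just for illustration:

```python
import numpy as np
import torch
import torch.nn.functional as F

eps = 1e-5
win = [9] * 3

def ncc_torch_fixed(I, J):
    # I, J: (N, C, D, H, W) tensors
    I2, J2, IJ = I * I, J * J, I * J
    channels = I.shape[1]
    # Fix 1: out_channels must be 1, so the filter is (1, channels, 9, 9, 9)
    sum_filt = torch.ones(1, channels, *win, dtype=I.dtype)
    # Fix 2: padding = 4 mimics TF 'SAME' for a 9x9x9 kernel at stride 1
    pad = win[0] // 2
    I_sum = F.conv3d(I, sum_filt, stride=1, padding=pad)
    J_sum = F.conv3d(J, sum_filt, stride=1, padding=pad)
    I2_sum = F.conv3d(I2, sum_filt, stride=1, padding=pad)
    J2_sum = F.conv3d(J2, sum_filt, stride=1, padding=pad)
    IJ_sum = F.conv3d(IJ, sum_filt, stride=1, padding=pad)

    win_size = np.prod(win)
    u_I = I_sum / win_size
    u_J = J_sum / win_size

    cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
    I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
    J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

    cc = cross * cross / (I_var * J_var + eps)
    return torch.mean(cc)
```

As a quick self-consistency check, `ncc_torch_fixed(I, I)` should be very close to 1, since `cross == I_var == J_var` when the two inputs are identical.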
