Depthwise convolution for gradient filters

I am trying to make a Sobel-type filter by adapting the following TensorFlow code to PyTorch.

import numpy as np
import skimage.data
import tensorflow as tf

def get_gradient_filters():
    # x-gradient kernel, normalized by the sum of its absolute weights (32)
    np_grad_x = np.asarray([[-3, 0, 3], [-10, 0, 10], [-3, 0, 3]], dtype=np.float32).reshape((3, 3, 1, 1))
    np_grad_x /= np.sum(np.abs(np_grad_x), keepdims=True)
    # stack the x-kernel and its transpose (the y-kernel) along the channel-multiplier axis
    tf_grad = tf.constant(np.concatenate([np_grad_x, np_grad_x.transpose((1, 0, 2, 3))], axis=-1))
    return tf_grad

tf_grad_filter = get_gradient_filters()
# repeat the (3, 3, 1, 2) filter over the 3 input channels -> (3, 3, 3, 2)
tf_grad_filter = tf.tile(tf_grad_filter, (1, 1, 3, 1))

image = skimage.data.astronaut().astype(np.float32) / 255.

tf_f = tf.expand_dims(tf.constant(image), 0)
tf_f_grad = tf.nn.depthwise_conv2d(tf_f, tf_grad_filter, [1, 1, 1, 1], "SAME") / (1. / 32)
tf_f_grad = tf.reshape(tf_f_grad, (1, 512, 512, 3, 2))
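
For reference, tf.nn.depthwise_conv2d takes its filter in (kH, kW, in_channels, channel_multiplier) layout and interleaves the output channels as k * channel_multiplier + q for input channel k and multiplier q; this is the ordering I need to reproduce on the PyTorch side. A quick shape check (added just for illustration):

print(tf_grad_filter.shape)  # (3, 3, 3, 2): 3 input channels, x- and y-filters per channel
print(tf_f_grad.shape)       # (1, 512, 512, 3, 2) after the final reshape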

Here is my attempt, which does not give the same result.

import torch

def gradient_filter(C):
    # same normalized kernels as in the TF version
    np_grad_x = np.asarray([[-3, 0, 3], [-10, 0, 10], [-3, 0, 3]], dtype=np.float32).reshape((3, 3, 1, 1))
    np_grad_x /= np.sum(np.abs(np_grad_x), keepdims=True)
    np_grad = np.concatenate([np_grad_x, np_grad_x.transpose((1, 0, 2, 3))], axis=-1)
    np_grad = np.tile(np_grad, (1, 1, C, 1))  # (3, 3, C, 2), like the TF depthwise filter
    # flatten (C, 2) -> 2*C, then move channels first: (2*C, 1, 3, 3) for a grouped Conv2d
    torch_grad = torch.Tensor(np_grad).reshape(1, 3, 3, 2 * C).permute(3, 0, 1, 2)
    filter = torch.nn.Conv2d(in_channels=C, out_channels=2 * C, kernel_size=(3, 3), padding=(1, 1),
                             padding_mode='zeros', groups=C, bias=False)
    filter.weight.data = torch_grad
    return filter

filter = gradient_filter(3)

def img_gradient(torch_f, filter):
    M, N, _ = torch_f.shape
    # HWC -> NCHW before the convolution
    return filter(torch_f.unsqueeze(dim=0).permute(0, 3, 1, 2)) / (1. / max(M, N))

torch_f = torch.Tensor(image)
torch_f_grad = img_gradient(torch_f, filter)
torch_f_grad = torch_f_grad.permute(0, 2, 3, 1).reshape(1, 512, 512, 3, 2)
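
For what it's worth, here is a sanity check on the weight layout that I would expect to pass (a sketch added just for illustration, rebuilding np_grad outside the function): a grouped Conv2d weight has shape (out_channels, in_channels/groups, kH, kW), so output channel 2*c + q should hold the kernel for input channel c and multiplier q, matching TF's depthwise ordering.

np_gx = np.asarray([[-3, 0, 3], [-10, 0, 10], [-3, 0, 3]], dtype=np.float32).reshape((3, 3, 1, 1))
np_gx /= np.sum(np.abs(np_gx))
np_g = np.tile(np.concatenate([np_gx, np_gx.transpose((1, 0, 2, 3))], axis=-1), (1, 1, 3, 1))
w = filter.weight.detach().numpy()
for c in range(3):
    for q in range(2):
        assert np.allclose(w[2 * c + q, 0], np_g[:, :, c, q])  # kernel for channel c, direction q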

I also tried some minor variations (e.g. first permute, then reshape), but the result is still not the same. Any ideas?

How large is the max. absolute difference? If it’s in the range of approximately 1e-6, you might be running into the floating-point precision limit of float32.
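
e.g. something like this, assuming both outputs from your snippets are available in the same session (tf_out / torch_out are just placeholder names):

tf_out = tf_f_grad.numpy()                 # assumes TF eager mode
torch_out = torch_f_grad.detach().numpy()
print(np.abs(tf_out - torch_out).max())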

The differences are not that small (some output values differ by as much as 1e2), but for some example images it did seem as if every term of one output corresponds to the matching term of the other output up to some precision error. What do you recommend in that case?

Edit: I tried float64, and the difference is now at most 1e-5. Is it possible to decrease this even further?
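
(Roughly what the float64 test looks like on the PyTorch side, as a sketch; the TF side needs the same dtype change for the kernels and the image constant:)

image64 = skimage.data.astronaut().astype(np.float64) / 255.
torch_f64 = torch.from_numpy(image64)     # float64 tensor
filter64 = gradient_filter(3).double()    # cast the conv weights to float64
torch_f_grad64 = img_gradient(torch_f64, filter64)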

Did you see a difference of 1e2 for some output values while others were much smaller?
If so, this sounds like an overflow issue, which I wouldn’t expect in float32.
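
If it shows up again, you could locate the offending entries directly, e.g. (reusing the tf_out / torch_out arrays from above):

diff = np.abs(tf_out - torch_out)
idx = np.unravel_index(np.argmax(diff), diff.shape)
print(idx, diff[idx], tf_out[idx], torch_out[idx])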

I cannot reproduce the example with the 1e2 difference right now; I think it was caused by something else at the time. It seems my code might be correct after all and the remaining difference is due to floating-point precision.
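
For anyone comparing similar ports, a tolerance check along these lines is probably the right acceptance test (a sketch, reusing tf_out / torch_out from above; the tolerance is loose enough for the float64 gap I saw):

print(np.allclose(tf_out, torch_out, atol=1e-4))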