How to implement tf.gather_nd with argument batch_dims in PyTorch?

I have been doing a project on image matching, so I need to find correspondences between 2 images. To get descriptors I will need an interpolate function. However, while reading an equivalent function written in TensorFlow, I still don't get how to implement tf.gather_nd(params, indices, batch_dims) in PyTorch, especially when there is the batch_dims argument. I have gone through Stack Overflow and there is no perfect equivalent yet.

Here is the interpolate function in TensorFlow, which I am trying to implement in PyTorch:

inputs is a dense feature map[i] taken from a for loop over the batch dimension, so it is 3D with shape [H, W, C] (in PyTorch it would be [C, H, W])
pos is a set of random point coordinates shaped like [[i, j], [i, j], …, [i, j]], so it is 2D when it goes into interpolate (in PyTorch it would be [[i, i, …, i], [j, j, …, j]])

The function first expands the dimensions of both arguments, so by the time tf.gather_nd is called, its params and indices arguments are 4D and 3D respectively.
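For concreteness, here is a plain-Python sketch of what tf.gather_nd with batch_dims=1 computes for these shapes (just the reference semantics, not how TensorFlow implements it):

import numpy as np

def gather_nd_batch_dims1(params, indices):
    # Reference semantics of tf.gather_nd(params, indices, batch_dims=1)
    # for params: (N, H, W, C) and indices: (N, P, 2) -> output: (N, P, C).
    n_batch, n_points = indices.shape[:2]
    out = np.empty((n_batch, n_points, params.shape[-1]), dtype=params.dtype)
    for n in range(n_batch):
        for p in range(n_points):
            i, j = indices[n, p]
            out[n, p] = params[n, i, j]  # copy one pixel's C-dim feature vector
    return out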
An example of using it:

pos = tf.ones((12, 2))  ## stands for a set of 12 coordinates, each a pair [i, j]
inputs = tf.ones((4, 4, 128))  ## stands for [H, W, C] of a dense feature map
outputs = interpolate(pos, inputs, nd=True)
print(outputs.get_shape())
>>> (12, 128)

The interpolate function (TensorFlow version):

import tensorflow as tf

def interpolate(pos, inputs, nd=True):
    # Add a batch dimension: pos -> (1, P, 2), inputs -> (1, H, W, C).
    pos = tf.expand_dims(pos, 0)
    inputs = tf.expand_dims(inputs, 0)

    h = tf.shape(inputs)[1]
    w = tf.shape(inputs)[2]

    i = pos[:, :, 0]
    j = pos[:, :, 1]

    # Indices of the four neighbouring pixels, clipped to the image bounds.
    i_top_left = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_top_right = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    i_bottom_left = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    # Bilinear weights from the fractional parts of the coordinates.
    dist_i_top_left = i - tf.cast(i_top_left, tf.float32)
    dist_j_top_left = j - tf.cast(j_top_left, tf.float32)
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    if nd:
        # Add a trailing axis so the weights broadcast over the channels.
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    # Gather the four corner values per batch element and blend them.
    interpolated_val = (
        w_top_left * tf.gather_nd(inputs, tf.stack([i_top_left, j_top_left], axis=-1), batch_dims=1) +
        w_top_right * tf.gather_nd(inputs, tf.stack([i_top_right, j_top_right], axis=-1), batch_dims=1) +
        w_bottom_left * tf.gather_nd(inputs, tf.stack([i_bottom_left, j_bottom_left], axis=-1), batch_dims=1) +
        w_bottom_right * tf.gather_nd(inputs, tf.stack([i_bottom_right, j_bottom_right], axis=-1), batch_dims=1)
    )

    # Remove the batch dimension again: (P, C).
    interpolated_val = tf.squeeze(interpolated_val, axis=0)
    return interpolated_val
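For comparison, here is a sketch of how the unbatched core of this function might look in PyTorch, assuming inputs has shape (H, W, C) and pos has shape (P, 2) like the TensorFlow version; plain advanced indexing takes the place of tf.gather_nd, and interpolate_torch is just an illustrative name:

import torch

def interpolate_torch(pos, inputs):
    # Bilinear interpolation of inputs (H, W, C) at float positions
    # pos (P, 2), each row [i, j]; a sketch mirroring the TF function,
    # without the batch expand/squeeze.
    h, w = inputs.shape[0], inputs.shape[1]

    i = pos[:, 0]
    j = pos[:, 1]

    # Indices of the four neighbouring pixels, clamped to the valid range.
    i_top = torch.clamp(torch.floor(i).long(), 0, h - 1)
    j_left = torch.clamp(torch.floor(j).long(), 0, w - 1)
    i_bottom = torch.clamp(torch.ceil(i).long(), 0, h - 1)
    j_right = torch.clamp(torch.ceil(j).long(), 0, w - 1)

    # Bilinear weights from the fractional parts of the coordinates.
    dist_i = i - i_top.float()
    dist_j = j - j_left.float()
    w_top_left = (1 - dist_i) * (1 - dist_j)
    w_top_right = (1 - dist_i) * dist_j
    w_bottom_left = dist_i * (1 - dist_j)
    w_bottom_right = dist_i * dist_j

    # inputs[row_idx, col_idx] gathers a (P, C) block, which is exactly
    # the role tf.gather_nd plays above.
    return (
        w_top_left[:, None] * inputs[i_top, j_left] +
        w_top_right[:, None] * inputs[i_top, j_right] +
        w_bottom_left[:, None] * inputs[i_bottom, j_left] +
        w_bottom_right[:, None] * inputs[i_bottom, j_right]
    )

Called as interpolate_torch(torch.rand(12, 2) * 3, torch.ones(4, 4, 128)), this returns a (12, 128) tensor, matching the TensorFlow example above.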

This post might provide you with a matching implementation of TensorFlow's gather_nd.
Let me know if that works for you.

Thank you for your reply! But it doesn't seem to be compatible with my task.

Here is an example of using tf.gather_nd in this context:

inputs shape: (1, 120, 160, 128)  ## [N, H, W, C], which stands for a dense feature map
indices shape: (1, 512, 2)  ## [N, num_points, 2], which stands for a set of 512 points, each a coordinate [i, j]
output = tf.gather_nd(inputs, indices, batch_dims=1)
print(output.get_shape())
>>> (1, 512, 128)
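This particular batch_dims=1 pattern can be reproduced in PyTorch with advanced indexing plus a broadcast batch index. A minimal sketch (gather_nd_torch is a made-up name, not a torch API):

import torch

def gather_nd_torch(params, indices):
    # Equivalent of tf.gather_nd(params, indices, batch_dims=1) for
    # params: (N, H, W, C) and indices: (N, P, 2) -> output: (N, P, C).
    batch = torch.arange(params.shape[0], device=params.device)[:, None]  # (N, 1), broadcasts over P
    return params[batch, indices[..., 0], indices[..., 1]]

inputs = torch.ones(1, 120, 160, 128)
indices = torch.zeros(1, 512, 2, dtype=torch.long)
print(gather_nd_torch(inputs, indices).shape)
>>> torch.Size([1, 512, 128])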

And I have just realized that in the interpolate function above, we could simply delete the dimension expanding at the beginning and the squeezing at the end to get the lower-dimensional case, which also keeps the result meaningful:

inputs shape: (120, 160, 128)
indices shape: (512, 2)
tf.gather_nd:  (512, 128)

Is there an equivalent in PyTorch that gets a result like this?

inputs: (1, 128, 120, 160)
indices: (1, 2, 512)
gather_nd_torch(inputs, indices)
>>> output: (1, 128, 512)

or this:

inputs: (128, 120, 160)
indices: (2, 512)
gather_nd_torch(inputs, indices)
>>> output: (128, 512)
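Both channels-first layouts can be handled with the same advanced-indexing trick; a sketch, assuming the indices are integer (long) pixel coordinates, and with made-up function names:

import torch

def gather_nd_torch(inputs, indices):
    # inputs: (C, H, W), indices: (2, P) long tensor with rows [i...], [j...]
    # -> output: (C, P); plain advanced indexing does the gather.
    return inputs[:, indices[0], indices[1]]

def gather_nd_torch_batched(inputs, indices):
    # inputs: (N, C, H, W), indices: (N, 2, P) -> output: (N, C, P).
    batch = torch.arange(inputs.shape[0], device=inputs.device)[:, None]  # (N, 1)
    # The advanced indices are separated by the ':' slice, so their
    # broadcast shape (N, P) moves to the front, giving (N, P, C);
    # permute back to (N, C, P).
    out = inputs[batch, :, indices[:, 0], indices[:, 1]]
    return out.permute(0, 2, 1)

inputs = torch.ones(128, 120, 160)
indices = torch.zeros(2, 512, dtype=torch.long)
print(gather_nd_torch(inputs, indices).shape)
>>> torch.Size([128, 512])

binputs = torch.ones(1, 128, 120, 160)
bindices = torch.zeros(1, 2, 512, dtype=torch.long)
print(gather_nd_torch_batched(binputs, bindices).shape)
>>> torch.Size([1, 128, 512])

For non-integer positions you would still apply the bilinear weighting shown in the interpolate sketch above; these helpers only cover the pure gather step.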