I have been working on an image matching project, so I need to find correspondences between two images. To get descriptors I need an interpolation function. However, after reading an equivalent function written in TensorFlow, I still don't understand how to implement tf.gather_nd(params, indices, batch_dims) in PyTorch, especially when the batch_dims argument is used. I have searched Stack Overflow and have not found an exact equivalent yet.
Here is the interpolate function in TensorFlow that I am trying to implement in PyTorch:
inputs is a single dense feature map (feature map[i], taken in a for loop over the batch), so it is 3D with shape [H, W, C] (in PyTorch it is [C, H, W]).
pos is a set of random point coordinates shaped like [[i, j], [i, j], …, [i, j]], so it is 2D when it goes into interpolate (in PyTorch it is [[i, i, …, i], [j, j, …, j]]).
Inside the function both tensors get an extra leading dimension, so the params and indices arguments passed to tf.gather_nd are 4D ([1, H, W, C]) and 3D ([1, N, 2]), with batch_dims=1.
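For the 4D params / 3D indices case described above, one possible way to mimic tf.gather_nd(..., batch_dims=1) in PyTorch is advanced indexing with an explicit batch index. This is only a minimal sketch: the helper name gather_nd_batch1 is hypothetical (it is not a PyTorch API), and it assumes params is [B, H, W, C] and indices is an integer tensor of shape [B, N, 2].

```python
import torch


def gather_nd_batch1(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mimics tf.gather_nd(params, indices, batch_dims=1)
    for params of shape [B, H, W, C] and integer indices of shape [B, N, 2]."""
    indices = indices.long()
    b = params.shape[0]
    # One batch index per row, shape [B, 1]; it broadcasts against [B, N].
    batch_idx = torch.arange(b, device=params.device).unsqueeze(1)
    i = indices[..., 0]  # row indices, [B, N]
    j = indices[..., 1]  # column indices, [B, N]
    # Advanced indexing over the first three dims leaves the channel dim intact.
    return params[batch_idx, i, j]  # [B, N, C]


params = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
indices = torch.tensor([[[0, 0], [2, 3]], [[1, 2], [0, 1]]])
out = gather_nd_batch1(params, indices)
print(out.shape)  # torch.Size([2, 2, 5])
```

Each batch element is indexed only with its own list of (i, j) points, which is the behaviour batch_dims=1 gives in TensorFlow.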
An example of using it:
pos = tf.ones((12, 2))  # a set of coordinates [[i, j], [i, j], ..., [i, j]]
inputs = tf.ones((4, 4, 128))  # [H, W, C] of the dense feature map
outputs = interpolate(pos, inputs, nd=True)
print(outputs.get_shape())
>>> (12, 128)
The interpolate function (TensorFlow version):
import tensorflow as tf

def interpolate(pos, inputs, nd=True):
    # Add a leading batch dimension: pos -> [1, N, 2], inputs -> [1, H, W, C].
    pos = tf.expand_dims(pos, 0)
    inputs = tf.expand_dims(inputs, 0)

    h = tf.shape(inputs)[1]
    w = tf.shape(inputs)[2]

    i = pos[:, :, 0]
    j = pos[:, :, 1]

    # Integer coordinates of the four neighbouring pixels, clipped to the map.
    i_top_left = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_top_right = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    i_bottom_left = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    # Bilinear weights from the fractional offsets to the top-left pixel.
    dist_i_top_left = i - tf.cast(i_top_left, tf.float32)
    dist_j_top_left = j - tf.cast(j_top_left, tf.float32)
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    if nd:
        # Add a trailing axis so the weights broadcast over the channels.
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    interpolated_val = (
        w_top_left * tf.gather_nd(inputs, tf.stack([i_top_left, j_top_left], axis=-1), batch_dims=1) +
        w_top_right * tf.gather_nd(inputs, tf.stack([i_top_right, j_top_right], axis=-1), batch_dims=1) +
        w_bottom_left * tf.gather_nd(inputs, tf.stack([i_bottom_left, j_bottom_left], axis=-1), batch_dims=1) +
        w_bottom_right * tf.gather_nd(inputs, tf.stack([i_bottom_right, j_bottom_right], axis=-1), batch_dims=1)
    )

    # Drop the batch dimension again: [1, N, C] -> [N, C].
    interpolated_val = tf.squeeze(interpolated_val, axis=0)
    return interpolated_val
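
For comparison, the same bilinear weighting could be sketched in PyTorch without any gather_nd call, because a single feature map can be indexed directly with integer tensors. This is a rough sketch and not a verified port: the name interpolate_torch is hypothetical, and it assumes the PyTorch layouts mentioned above (inputs is [C, H, W], pos is [2, N] with row 0 holding i and row 1 holding j) and returns [N, C] to match the TensorFlow output shape.

```python
import torch


def interpolate_torch(pos: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of bilinear sampling.
    pos: [2, N] float coordinates (i in row 0, j in row 1); inputs: [C, H, W].
    Returns a tensor of shape [N, C]."""
    _, h, w = inputs.shape
    i, j = pos[0], pos[1]

    # Integer neighbours, clipped to the feature map like tf.clip_by_value.
    i_top = torch.floor(i).long().clamp(0, h - 1)
    i_bottom = torch.ceil(i).long().clamp(0, h - 1)
    j_left = torch.floor(j).long().clamp(0, w - 1)
    j_right = torch.ceil(j).long().clamp(0, w - 1)

    # Fractional offsets to the top-left pixel give the bilinear weights.
    dist_i = i - i_top.float()
    dist_j = j - j_left.float()
    w_top_left = (1 - dist_i) * (1 - dist_j)
    w_top_right = (1 - dist_i) * dist_j
    w_bottom_left = dist_i * (1 - dist_j)
    w_bottom_right = dist_i * dist_j

    # inputs[:, rows, cols] has shape [C, N]; the [N] weights broadcast over C.
    val = (
        w_top_left * inputs[:, i_top, j_left]
        + w_top_right * inputs[:, i_top, j_right]
        + w_bottom_left * inputs[:, i_bottom, j_left]
        + w_bottom_right * inputs[:, i_bottom, j_right]
    )
    return val.t()  # [N, C]


pos = torch.ones(2, 12)          # 12 points, all at (i, j) = (1, 1)
inputs = torch.ones(128, 4, 4)   # [C, H, W] dense feature map
outputs = interpolate_torch(pos, inputs)
print(outputs.shape)  # torch.Size([12, 128])
```

A batched variant would need an extra batch index in the indexing step, which is where the batch_dims=1 behaviour of tf.gather_nd comes in.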