Pointer network implementation for transformers

Is there an implementation of a pointer network for transformers? I found one in TensorFlow at https://github.com/xiongma/transformer-pointer-generator
I tried converting it to PyTorch, but I can't find a counterpart for the following function.
The function is:
loss = tf.map_fn(fn=lambda x: tf.gather_nd(x[1], x[0]), elems=(indices, final_dists), dtype=tf.float32)

It is at
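
For reference, here is a minimal sketch of what a PyTorch counterpart of that line might look like. It rests on assumptions the post does not confirm: that `indices` has shape [batch, T, 2] and holds (decoder step, vocabulary id) pairs, and that `final_dists` has shape [batch, T, vocab]. The helper name `gather_nd_per_example` is hypothetical, not something from the linked repo.

```python
import torch


def gather_nd_per_example(indices: torch.Tensor, final_dists: torch.Tensor) -> torch.Tensor:
    """Sketch of tf.map_fn(lambda x: tf.gather_nd(x[1], x[0]), (indices, final_dists)).

    Assumed shapes (not confirmed by the original post):
      indices:     [batch, T, 2], integer (step, vocab id) pairs
      final_dists: [batch, T, vocab]
    Returns a [batch, T] tensor with out[b, t] = final_dists[b, i, j],
    where (i, j) = indices[b, t].
    """
    batch_size = final_dists.size(0)
    # Column vector of batch positions; it broadcasts against the [batch, T] index tensors.
    batch_idx = torch.arange(batch_size, device=final_dists.device).unsqueeze(1)
    return final_dists[batch_idx, indices[..., 0], indices[..., 1]]


if __name__ == "__main__":
    B, T, V = 2, 3, 5
    final_dists = torch.softmax(torch.randn(B, T, V), dim=-1)
    targets = torch.randint(0, V, (B, T))
    # Pair every decoder step with its target id, giving shape [B, T, 2].
    steps = torch.arange(T).unsqueeze(0).expand(B, T)
    indices = torch.stack([steps, targets], dim=2)
    picked = gather_nd_per_example(indices, final_dists)
    print(picked.shape)
```

If the first entry of each pair is simply the decoder step 0..T-1, as in the demo above, then `torch.gather(final_dists, 2, targets.unsqueeze(-1)).squeeze(-1)` should give the same result without building `indices` at all.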