I found a possible solution here: How to do the tf.gather_nd in pytorch? - #18 by Cogito2012. After applying it, I do get an output of the desired shape. However, the operations that follow it keep triggering this error: RuntimeError: CUDA error: device-side assert triggered
Edit: This works! It turns out there was something wrong with my indices.
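For anyone hitting the same assert, here is a minimal sketch of how the indices could be checked before indexing. It is not the code from the linked post: `gather_nd` below is a hypothetical helper I wrote for illustration, it covers only the non-batched case, and it assumes `indices` has shape (..., K) with K no larger than the number of dimensions of `params`. An out-of-range index on the GPU typically surfaces as a device-side assert, often at a later operation, so validating up front (or reproducing on the CPU) tends to give a clearer message.

```python
import torch


def gather_nd(params: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical tf.gather_nd-style helper (non-batched case only).

    Each length-K row of `indices` addresses the first K dims of `params`.
    """
    k = indices.shape[-1]
    if k > params.dim():
        raise ValueError(f"index depth {k} exceeds params rank {params.dim()}")

    idx = indices.long()

    # Validate every index component against the size of the dim it addresses.
    limits = torch.tensor(list(params.shape[:k]), device=idx.device)
    out_of_range = (idx < 0) | (idx >= limits)
    if out_of_range.any():
        raise IndexError(
            f"indices out of range for params with leading shape "
            f"{tuple(params.shape[:k])}"
        )

    # Split the last dim into K index tensors and use advanced indexing.
    return params[tuple(idx.unbind(dim=-1))]


params = torch.arange(12).reshape(3, 4)
good = torch.tensor([[0, 1], [2, 3]])
result = gather_nd(params, good)      # picks params[0, 1] and params[2, 3]

bad = torch.tensor([[0, 4]])          # column 4 does not exist in a 3x4 tensor
try:
    gather_nd(params, bad)
except IndexError as err:
    message = str(err)                # readable error instead of a CUDA assert
```

If the assert still shows up, running the same code on the CPU, or setting the environment variable CUDA_LAUNCH_BLOCKING=1, usually helps pinpoint which operation actually failed.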