Thank you for the sample code; it helps me unsort the indices. However, how could I do something similar for the hidden states, to put them back in the original input order? I'm having trouble because the hidden state is a 3D tensor while ind is only 1D.
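Here is a minimal sketch of one way to do it. It assumes the batch was sorted with lengths.sort(descending=True), so sorted[i] == original[ind[i]]. It also assumes the hidden state has the usual nn.LSTM / nn.GRU shape (num_layers * num_directions, batch, hidden_size). Under those assumptions the 1D ind only has to act on dimension 1: invert the permutation once with torch.argsort and then call index_select along the batch axis. The helper name unsort_hidden and the toy sizes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

def unsort_hidden(h, ind):
    """Hypothetical helper: restore the batch axis (dim 1) of h to the original order.

    h:   (num_layers * num_directions, batch, hidden_size), computed on the sorted batch
    ind: 1D LongTensor from lengths.sort(descending=True); sorted[i] == original[ind[i]]
    """
    # Inverse permutation: original sample j sits at sorted position unsort_ind[j].
    unsort_ind = torch.argsort(ind)
    # The 1D index is applied along the batch axis only; layers and features are untouched.
    return h.index_select(1, unsort_ind)

torch.manual_seed(0)
batch, max_len, feat, hidden = 4, 5, 3, 6
x = torch.randn(batch, max_len, feat)
lengths = torch.tensor([2, 5, 3, 4])

# Sort the batch by length, pack it, and run the RNN on the sorted batch.
sorted_lengths, ind = lengths.sort(descending=True)
packed = pack_padded_sequence(x[ind], sorted_lengths, batch_first=True)

lstm = nn.LSTM(feat, hidden, batch_first=True)
_, (h_n, c_n) = lstm(packed)

# Put both final states back in the original input order.
h_n = unsort_hidden(h_n, ind)
c_n = unsort_hidden(c_n, ind)
```

The same indexing works for the cell state and for multi-layer or bidirectional RNNs, because the batch is always dimension 1 of h_n and c_n. Recent PyTorch versions (1.1 and later) also accept enforce_sorted=False in pack_padded_sequence. In that case the RNN module returns h_n already in the original order, so no manual unsorting is needed.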
Thank you!
EDIT: A way to do this without scatter is described in "How to properly unsort unpacked sequences?", but I'm not sure the gradient propagates correctly through it, and it doesn't seem very efficient. Can someone confirm this?
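The gradient question can be checked empirically with a small test. The sketch below uses hypothetical helper names and is not the code from the linked thread. It compares a scatter-based unsort with an argsort/index_select one, then runs torch.autograd.gradcheck on both. Both index_select and scatter are differentiable with respect to the values being moved. The integer index tensor receives no gradient, which is expected for a permutation. On cost, argsort runs once on a 1D tensor of batch size, which should be negligible next to the RNN itself.

```python
import torch

def unsort_index_select(h, ind):
    # Gather along the batch axis (dim 1) using the inverse permutation.
    return h.index_select(1, torch.argsort(ind))

def unsort_scatter(h, ind):
    # Write sorted slice i back to its original slot: out[:, ind[i], :] = h[:, i, :].
    index = ind.view(1, -1, 1).expand_as(h)
    return torch.zeros_like(h).scatter(1, index, h)

torch.manual_seed(0)
lengths = torch.tensor([2, 5, 3, 4])
_, ind = lengths.sort(descending=True)

# gradcheck needs double precision inputs that require grad.
h = torch.randn(2, 4, 3, dtype=torch.double, requires_grad=True)

# Both approaches should produce the same tensor.
same = torch.equal(unsort_index_select(h, ind), unsort_scatter(h, ind))

# Compare analytical and numerical gradients for each approach.
ok_gather = torch.autograd.gradcheck(lambda t: unsort_index_select(t, ind), (h,))
ok_scatter = torch.autograd.gradcheck(lambda t: unsort_scatter(t, ind), (h,))
print(same, ok_gather, ok_scatter)
```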