I dug into it more, and it seems that, for the PyTorch version, almost all of the RNN outputs converge to +1/-1 — which suggests we are saturating the tanh activation.
It’s strange that this occurs with PyTorch but not TensorFlow. I was also able to overfit a sample training batch on TF but not on PyTorch.
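For reference, here is the kind of check I ran — a minimal sketch (the layer sizes and dummy input are made up, not my actual model) that measures what fraction of an `nn.RNN`'s tanh outputs land near +/-1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy RNN with the default nonlinearity='tanh'; sizes are arbitrary.
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 10, 8)  # dummy input: (batch, seq_len, features)

with torch.no_grad():
    out, _ = rnn(x)  # out: (batch, seq_len, hidden_size)

# Fraction of hidden activations in tanh's flat regions (|h| near 1).
# In the failing run this was close to 1.0 for almost every time step.
saturated = (out.abs() > 0.99).float().mean().item()
print(f"fraction of outputs with |h| > 0.99: {saturated:.3f}")
```

With freshly initialized weights and unit-variance inputs this fraction should be small; seeing it near 1.0 after a few training steps is what made me suspect saturation rather than a data issue.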
Someone had a similar problem here.
If you have any suggestions, please let me know. Thanks!