Hey guys, I figured it out.
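As someone new to BERT embeddings, I forgot to wrap the code that generates my embeddings in
with torch.no_grad():
Without it, PyTorch records an autograd graph for every forward pass, even when you only need the embeddings as features.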
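For anyone who wants to see the fix in context, here is a minimal sketch that assumes a plain PyTorch setup. `TinyEncoder` is a hypothetical stand-in for BERT so the example runs without downloading weights, and `embed_batches` is a hypothetical helper name. With a real Hugging Face BERT model you would call it inside the same `with torch.no_grad():` block.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a BERT encoder (token ids -> per-token vectors)."""

    def __init__(self, vocab_size: int = 100, hidden: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) -> (batch, seq_len, hidden)
        return torch.tanh(self.proj(self.embed(input_ids)))


def embed_batches(model: nn.Module, batches: list) -> torch.Tensor:
    """Mean-pooled embeddings for each batch, computed without gradient tracking."""
    model.eval()  # turns off dropout; this is separate from no_grad
    pooled = []
    with torch.no_grad():  # no autograd graph is built, so activations are not kept for backward
        for input_ids in batches:
            hidden = model(input_ids)            # (batch, seq_len, hidden)
            pooled.append(hidden.mean(dim=1))    # (batch, hidden)
    return torch.cat(pooled)


if __name__ == "__main__":
    model = TinyEncoder()
    batches = [torch.randint(0, 100, (4, 16)) for _ in range(3)]
    emb = embed_batches(model, batches)
    print(emb.shape)          # torch.Size([12, 32])
    print(emb.requires_grad)  # False: detached from any graph

    # For contrast: the same forward pass outside no_grad is attached to a graph.
    print(model(batches[0]).requires_grad)  # True
```

The main point is that everything producing the embeddings sits inside the `no_grad` block. If you only wrap the code that uses the embeddings afterwards, the graph has already been built during the forward pass.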
I pray that any new NLP learners out there find this post and don't spend 12+ hours debugging like I did. Good luck to you all!