Fixed my own problem. In case anyone else runs into the same issue, the culprit was my custom dataset class. I mistakenly believed that Python makes a copy of an object on assignment, e.g.
a = [5, 5, 5]; b = a
would create two independent lists, when in fact b is just another reference to the same object as a. This meant that the normalization code in my dataset class was modifying my stored data in place every time a sample passed through it. The normalization was therefore applied again on every epoch, which eventually led to inf values appearing after a certain number of epochs.
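A minimal illustration of the aliasing behaviour described above, in plain Python: assignment only binds a second name to the same list, while an explicit copy (here `list.copy()`, a shallow copy) produces an independent object.

```python
# Assignment binds a new name to the same object; it does not copy it.
a = [5, 5, 5]
b = a            # b and a now refer to the same list object
b[0] = 100       # mutating through b...
print(a)         # ...is visible through a: [100, 5, 5]
print(a is b)    # True: same object

# An explicit copy creates an independent object.
c = a.copy()     # list.copy() makes a shallow copy
c[0] = -1
print(a)         # [100, 5, 5]: unchanged by the edit to c
print(a is c)    # False: different objects
```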
My solution was to use ndarray.copy() so that the original data is no longer modified.
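The post doesn't include the actual dataset code, so here is a rough sketch of how the bug and the fix could look. It assumes NumPy data of shape (num_samples, num_features) and a `(x - mean) / std` normalization; the class names, `run_epochs`, and the `mean`/`std` values are all hypothetical. The classes just follow the `__len__`/`__getitem__` pattern of a typical custom dataset and don't depend on any framework. Indexing a row of a 2D array returns a view, so in-place `-=`/`/=` writes straight back into the stored data, while `.copy()` keeps it intact.

```python
import numpy as np


class InPlaceNormDataset:
    """Hypothetical dataset reproducing the bug: normalizes samples in place."""

    def __init__(self, data, mean, std):
        self.data = data    # shape (num_samples, num_features)
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]   # basic indexing returns a view of the row
        sample -= self.mean       # in-place ops write through to self.data
        sample /= self.std
        return sample


class CopyNormDataset(InPlaceNormDataset):
    """Same normalization, but applied to a copy of the stored row."""

    def __getitem__(self, idx):
        sample = self.data[idx].copy()  # ndarray.copy(): independent array
        sample -= self.mean             # only the copy is modified
        sample /= self.std
        return sample


def run_epochs(dataset, num_epochs):
    """Fetch every sample once per epoch, like a simple training loop."""
    for _ in range(num_epochs):
        for i in range(len(dataset)):
            dataset[i]


raw = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

buggy = InPlaceNormDataset(raw.copy(), mean=0.0, std=0.5)
fixed = CopyNormDataset(raw.copy(), mean=0.0, std=0.5)

with np.errstate(over="ignore"):  # silence the float32 overflow warning
    run_epochs(buggy, 200)
run_epochs(fixed, 200)

print(np.isinf(buggy.data).any())       # True: stored data blew up to inf
print(np.array_equal(fixed.data, raw))  # True: stored data untouched
```

With std < 1, each in-place pass doubles the stored values, so float32 overflows to inf after roughly 128 epochs in this toy setup. The copying version returns the same normalized sample every time.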