NaN training and testing loss

When trying to use an LSTM model for regression, I get NaN values when I print the training and testing loss. The DataFrame I pass into the model has no NaN values, so I believe the issue is in my model or in my training/testing loop functions. Any help would be greatly appreciated.

Model Class:

import torch
from torch import nn

class PKLSTM(nn.Module):
    def __init__(self, input_size, hidden_units, num_layers):
        super(PKLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_units = hidden_units
        self.num_layers = num_layers 

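        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_units,
            num_layers=num_layers,
            batch_first=True,  # forward() reads the batch size from x.shape[0]
        )
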
        self.linear = nn.Linear(hidden_units, 1)
    def forward(self, x):
        batch_size = x.shape[0]
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_units).requires_grad_()
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_units).requires_grad_()

        _, (hn, _) = self.lstm(x, (h0, c0))
        out = self.linear(hn[0]).flatten()

        return out

Training and testing loop:

def train_model(data_loader, model, loss_function, optimizer):
    num_batches = len(data_loader)
    total_loss = 0

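    model.train()
    for X, y in data_loader:
        output = model(X)
        loss = loss_function(output, y)

        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()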

    avg_loss = total_loss / num_batches
    print(f"Train loss: {avg_loss}")

def test_model(data_loader, model, loss_function):

    num_batches = len(data_loader)
    # print('Num of batches', num_batches)
    total_loss = 0

    with torch.no_grad():
        for X, y in data_loader:
            # print('input: ', X)
            output = model(X)
            # print('output: ', output)
            total_loss += loss_function(output, y).item()

    avg_loss = total_loss / num_batches
    print(f"Test loss: {avg_loss}")

If your learning rate is too large, the model will diverge, which eventually leads to Inf/NaN values in the loss.
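One way to guard against divergence is to stop as soon as the loss becomes non-finite and to clip the gradient norm before each update. The sketch below is only illustrative: `train_step` and the default `max_grad_norm=1.0` are hypothetical names and values, not part of the code in the question.

```python
import torch
from torch import nn


def train_step(model, X, y, loss_function, optimizer, max_grad_norm=1.0):
    """Run one optimisation step; raise if the loss is NaN or Inf."""
    optimizer.zero_grad()
    loss = loss_function(model(X), y)
    if not torch.isfinite(loss):
        # Fail early so the offending batch can be inspected.
        raise RuntimeError(f"Non-finite loss: {loss.item()}")
    loss.backward()
    # Rescale gradients so their total norm does not exceed max_grad_norm.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()
```

If the error still fires on the very first batch, the learning rate is probably not the cause and the inputs are worth checking instead.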
The second possibility is that there is a NaN in the input data. It may not appear in the DataFrame itself but could be introduced during preprocessing.
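To check what the model actually receives, you can scan the batches coming out of the DataLoader rather than the DataFrame. This is a minimal sketch: `find_bad_batches` is a hypothetical helper, and the toy tensors stand in for your own dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def find_bad_batches(data_loader):
    """Return indices of batches whose inputs or targets contain NaN or Inf."""
    bad = []
    for i, (X, y) in enumerate(data_loader):
        if not (torch.isfinite(X).all() and torch.isfinite(y).all()):
            bad.append(i)
    return bad


# Toy data: 6 sequences of length 4 with 2 features; sample 3 holds a NaN.
X = torch.ones(6, 4, 2)
X[3, 0, 0] = float("nan")
y = torch.zeros(6)
loader = DataLoader(TensorDataset(X, y), batch_size=2, shuffle=False)

print(find_bad_batches(loader))  # prints [1]
```

Running the same check on your training and test loaders should point to the batches where the dataset class or preprocessing introduces the bad values.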

Thank you. Some of my code in my dataset class was generating NaNs. Appreciate all the help!