NaN loss function value

Hello all,

I have made the following very simple Neural network.

import torch
import numpy as np
import torch.nn as nn

from torch.nn.utils.parametrizations import spectral_norm
from torch.utils.data import Dataset, DataLoader

# Custom Dataset class
class DisturbanceDataset(Dataset):
    
    def __init__(self):
        # data load
        xy = np.loadtxt('data.csv', delimiter=",", dtype=np.float32)
        self.x = torch.from_numpy(xy[:,0:21])
        self.y = torch.from_numpy(xy[:, 21:])
        self.n_samples = xy.shape[0]
        
    def __getitem__(self, index):
        # dataset[0]
        return self.x[index, :], self.y[index, :]

    def __len__(self):
        # len(dataset)
        return self.n_samples

class NeuralNet(nn.Module):
    """Two-layer MLP: 21 features -> 36 hidden units -> 9 outputs.

    The first layer is wrapped in spectral norm, which constrains its
    largest singular value; the hidden activation is ReLU.
    """

    def __init__(self):
        super(NeuralNet, self).__init__()
        # Spectral-normalized input layer (21 -> 36).
        self.linear1 = spectral_norm(nn.Linear(21, 36))
        self.relu = nn.ReLU()
        # Plain output layer (36 -> 9).
        self.linear2 = nn.Linear(36, 9)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

# device config: prefer GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
input_size = 21
num_classes = 9
num_epochs = 1
batch_size = 32
learning_rate = 0.00001


train_dataset = DisturbanceDataset()
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

model = NeuralNet().to(device)

# Loss: mean-squared error over the 9-dimensional regression target.
criterion = nn.MSELoss()

# Optimizer: plain SGD; the tiny learning rate alone did not prevent
# divergence with the unscaled inputs (loss -> inf -> NaN in the log).
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

n_total_steps = len(train_dataloader)

for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward Pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm so a single bad batch cannot
        # blow the weights up to inf/NaN (the observed failure mode).
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        print(f'epoch {epoch+1} / {num_epochs}, step {i+1}/{n_total_steps}, loss = {loss.item():.4f}')

The output is this:

epoch 1 / 1, step 1/157, loss = 2440638842285636714496.0000
epoch 1 / 1, step 2/157, loss = inf
epoch 1 / 1, step 3/157, loss = nan
epoch 1 / 1, step 4/157, loss = nan
epoch 1 / 1, step 5/157, loss = nan
epoch 1 / 1, step 6/157, loss = nan
epoch 1 / 1, step 7/157, loss = nan
epoch 1 / 1, step 8/157, loss = nan
epoch 1 / 1, step 9/157, loss = nan
epoch 1 / 1, step 10/157, loss = nan
epoch 1 / 1, step 11/157, loss = nan
epoch 1 / 1, step 12/157, loss = nan
epoch 1 / 1, step 13/157, loss = nan
epoch 1 / 1, step 14/157, loss = nan
epoch 1 / 1, step 15/157, loss = nan
epoch 1 / 1, step 16/157, loss = nan
epoch 1 / 1, step 17/157, loss = nan
epoch 1 / 1, step 18/157, loss = nan
epoch 1 / 1, step 19/157, loss = nan
epoch 1 / 1, step 20/157, loss = nan
epoch 1 / 1, step 21/157, loss = nan
epoch 1 / 1, step 22/157, loss = nan
epoch 1 / 1, step 23/157, loss = nan
epoch 1 / 1, step 24/157, loss = nan
epoch 1 / 1, step 25/157, loss = nan
epoch 1 / 1, step 26/157, loss = nan
epoch 1 / 1, step 27/157, loss = nan
epoch 1 / 1, step 28/157, loss = nan
epoch 1 / 1, step 29/157, loss = nan
epoch 1 / 1, step 30/157, loss = nan
epoch 1 / 1, step 31/157, loss = nan
epoch 1 / 1, step 32/157, loss = nan
epoch 1 / 1, step 33/157, loss = nan
epoch 1 / 1, step 34/157, loss = nan
epoch 1 / 1, step 35/157, loss = nan
epoch 1 / 1, step 36/157, loss = nan
epoch 1 / 1, step 37/157, loss = nan
epoch 1 / 1, step 38/157, loss = nan
epoch 1 / 1, step 39/157, loss = nan
epoch 1 / 1, step 40/157, loss = nan
epoch 1 / 1, step 41/157, loss = nan
epoch 1 / 1, step 42/157, loss = nan
epoch 1 / 1, step 43/157, loss = nan
epoch 1 / 1, step 44/157, loss = nan
epoch 1 / 1, step 45/157, loss = nan
epoch 1 / 1, step 46/157, loss = nan
epoch 1 / 1, step 47/157, loss = nan
epoch 1 / 1, step 48/157, loss = nan
epoch 1 / 1, step 49/157, loss = nan
epoch 1 / 1, step 50/157, loss = nan
epoch 1 / 1, step 51/157, loss = nan
epoch 1 / 1, step 52/157, loss = nan
epoch 1 / 1, step 53/157, loss = nan
epoch 1 / 1, step 54/157, loss = nan
epoch 1 / 1, step 55/157, loss = nan
epoch 1 / 1, step 56/157, loss = nan
epoch 1 / 1, step 57/157, loss = nan
epoch 1 / 1, step 58/157, loss = nan
epoch 1 / 1, step 59/157, loss = nan
epoch 1 / 1, step 60/157, loss = nan
epoch 1 / 1, step 61/157, loss = nan
epoch 1 / 1, step 62/157, loss = nan
epoch 1 / 1, step 63/157, loss = nan
epoch 1 / 1, step 64/157, loss = nan
epoch 1 / 1, step 65/157, loss = nan
epoch 1 / 1, step 66/157, loss = nan
epoch 1 / 1, step 67/157, loss = nan
epoch 1 / 1, step 68/157, loss = nan
epoch 1 / 1, step 69/157, loss = nan
epoch 1 / 1, step 70/157, loss = nan
epoch 1 / 1, step 71/157, loss = nan
epoch 1 / 1, step 72/157, loss = nan
epoch 1 / 1, step 73/157, loss = nan
epoch 1 / 1, step 74/157, loss = nan
epoch 1 / 1, step 75/157, loss = nan
epoch 1 / 1, step 76/157, loss = nan
epoch 1 / 1, step 77/157, loss = nan
epoch 1 / 1, step 78/157, loss = nan
epoch 1 / 1, step 79/157, loss = nan
epoch 1 / 1, step 80/157, loss = nan
epoch 1 / 1, step 81/157, loss = nan
epoch 1 / 1, step 82/157, loss = nan
epoch 1 / 1, step 83/157, loss = nan
epoch 1 / 1, step 84/157, loss = nan
epoch 1 / 1, step 85/157, loss = nan
epoch 1 / 1, step 86/157, loss = nan
epoch 1 / 1, step 87/157, loss = nan
epoch 1 / 1, step 88/157, loss = nan
epoch 1 / 1, step 89/157, loss = nan
epoch 1 / 1, step 90/157, loss = nan
epoch 1 / 1, step 91/157, loss = nan
epoch 1 / 1, step 92/157, loss = nan
epoch 1 / 1, step 93/157, loss = nan
epoch 1 / 1, step 94/157, loss = nan
epoch 1 / 1, step 95/157, loss = nan
epoch 1 / 1, step 96/157, loss = nan
epoch 1 / 1, step 97/157, loss = nan
epoch 1 / 1, step 98/157, loss = nan
epoch 1 / 1, step 99/157, loss = nan
epoch 1 / 1, step 100/157, loss = nan
epoch 1 / 1, step 101/157, loss = nan
epoch 1 / 1, step 102/157, loss = nan
epoch 1 / 1, step 103/157, loss = nan
epoch 1 / 1, step 104/157, loss = nan
epoch 1 / 1, step 105/157, loss = nan
epoch 1 / 1, step 106/157, loss = nan
epoch 1 / 1, step 107/157, loss = nan
epoch 1 / 1, step 108/157, loss = nan
epoch 1 / 1, step 109/157, loss = nan
epoch 1 / 1, step 110/157, loss = nan
epoch 1 / 1, step 111/157, loss = nan
epoch 1 / 1, step 112/157, loss = nan
epoch 1 / 1, step 113/157, loss = nan
epoch 1 / 1, step 114/157, loss = nan
epoch 1 / 1, step 115/157, loss = nan
epoch 1 / 1, step 116/157, loss = nan
epoch 1 / 1, step 117/157, loss = nan
epoch 1 / 1, step 118/157, loss = nan
epoch 1 / 1, step 119/157, loss = nan
epoch 1 / 1, step 120/157, loss = nan
epoch 1 / 1, step 121/157, loss = nan
epoch 1 / 1, step 122/157, loss = nan
epoch 1 / 1, step 123/157, loss = nan
epoch 1 / 1, step 124/157, loss = nan
epoch 1 / 1, step 125/157, loss = nan
epoch 1 / 1, step 126/157, loss = nan
epoch 1 / 1, step 127/157, loss = nan
epoch 1 / 1, step 128/157, loss = nan
epoch 1 / 1, step 129/157, loss = nan
epoch 1 / 1, step 130/157, loss = nan
epoch 1 / 1, step 131/157, loss = nan
epoch 1 / 1, step 132/157, loss = nan
epoch 1 / 1, step 133/157, loss = nan
epoch 1 / 1, step 134/157, loss = nan
epoch 1 / 1, step 135/157, loss = nan
epoch 1 / 1, step 136/157, loss = nan
epoch 1 / 1, step 137/157, loss = nan
epoch 1 / 1, step 138/157, loss = nan
epoch 1 / 1, step 139/157, loss = nan
epoch 1 / 1, step 140/157, loss = nan
epoch 1 / 1, step 141/157, loss = nan
epoch 1 / 1, step 142/157, loss = nan
epoch 1 / 1, step 143/157, loss = nan
epoch 1 / 1, step 144/157, loss = nan
epoch 1 / 1, step 145/157, loss = nan
epoch 1 / 1, step 146/157, loss = nan
epoch 1 / 1, step 147/157, loss = nan
epoch 1 / 1, step 148/157, loss = nan
epoch 1 / 1, step 149/157, loss = nan
epoch 1 / 1, step 150/157, loss = nan
epoch 1 / 1, step 151/157, loss = nan
epoch 1 / 1, step 152/157, loss = nan
epoch 1 / 1, step 153/157, loss = nan
epoch 1 / 1, step 154/157, loss = nan
epoch 1 / 1, step 155/157, loss = nan
epoch 1 / 1, step 156/157, loss = nan
epoch 1 / 1, step 157/157, loss = nan

The data is 10000 samples with 21 input features. The number of outputs is 9. The data can be accessed through this link:
https://drive.google.com/file/d/1st_rDGBr4NdTn_5ByCToMQ64lvToX66O/view?usp=sharing

I am not sure where I can improve. I understand the NN can be deeper but I wanted to get this simple thing working first.

Thanks in advance.

Based on the initial loss, it seems the training diverges immediately.
Using a random (normally distributed) dataset seems to work:

train_dataset = torch.utils.data.TensorDataset(torch.randn(32*10, 21), torch.randn(32*10, 9))

Output:

epoch 1 / 1, step 1/10, loss = 1.0535
epoch 1 / 1, step 2/10, loss = 1.0202
epoch 1 / 1, step 3/10, loss = 0.9514
epoch 1 / 1, step 4/10, loss = 1.2088
epoch 1 / 1, step 5/10, loss = 0.9482
epoch 1 / 1, step 6/10, loss = 1.0481
epoch 1 / 1, step 7/10, loss = 0.9500
epoch 1 / 1, step 8/10, loss = 1.0517
epoch 1 / 1, step 9/10, loss = 1.0937
epoch 1 / 1, step 10/10, loss = 0.9892

(at least the loss doesn’t explode), so could you check the range of your input data and normalize it if needed?