Predictions become NaN after a few batches

I am a beginner with PyTorch.

I am trying to use a pre-trained model for a classification problem.

But I found that both my loss and my predictions become NaN after the first epoch.

So I stepped through the process to see what happens:

  1. I checked whether my data contain NaN values; they do not.

  2. I reduced the learning rate to a much smaller value (1e-10), but the loss is still NaN.

  3. I added a break switch that stops training at the first NaN prediction (a sketch of this check follows the list), and here I found something.
    In the batch before the break, everything is OK: the predictions and the loss are not NaN.
    In the batch that triggers the break, the input x is not NaN and the parameters are not NaN,
    but the predictions are NaN, which then turns everything NaN from that point on.
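A simplified sketch of that break switch (funModel and ptrTrainAugLoader are defined in the full code below):

import torch

# Stop at the first batch whose predictions contain NaN.
for x, y in ptrTrainAugLoader:
    assert not torch.isnan(x).any(), "NaN in the input batch"
    tenOutput = funModel.cuda()(x.cuda())
    if torch.isnan(tenOutput).any():
        print("Predictions turned NaN in this batch")
        break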

How can I understand what is happening in the training process?

Is something wrong with my data? Or can I modify my model to fix it?

Here is my code below:

Main

##
##  Load
from Requirement import *
os.makedirs("Holdout/", exist_ok = True)
##
##  Data table
dataTable = ChestDataTable()
##
##  Holdout
SetSeed()
dataTrainTable, dataValidTable = holdout(dataTable, test_size = 0.05, stratify = dataTable['Target'])

##
##  Augmentation
dataTrainAugTable = AugChestTable(dataTrainTable, 3500)
##
##  Data set
ptrTrainAugSet = ChestDataSet(dataTrainAugTable)
ptrValidSet    = ChestDataSet(dataValidTable)
##
##  Parameter set
listBatch     = [2]
listEpoch     = [5]
listOptimizer = ["Adam"]
listEta       = [1e-6]
dictParameterSet = {"Batch": listBatch, "Epoch": listEpoch, "Optimizer": listOptimizer, "Eta": listEta}
dataParameterSet = DataFrame(ParameterGrid(dictParameterSet)).reset_index()
dataParameterSet.to_csv("Holdout/ParameterSet.csv", index = False)
##
##  In parameter loop
for i, p in dataParameterSet.iterrows():
    ##
    ##  Model
    funModel = DenseBaseNet()
    ##
    ##  Set parameter
    intBatch     = p['Batch']
    intEpoch     = p['Epoch']
    strOptimizer = p['Optimizer']
    floatEta     = p['Eta']
    funCriterion = CrossEntropyLoss()
    if(strOptimizer=="Adam"):
        ptrOptimizer = optim.Adam(funModel.parameters(), lr=floatEta)
    ##
    ##  Loader
    ptrTrainAugLoader = DataLoader(dataset=ptrTrainAugSet, batch_size=intBatch, shuffle=True, num_workers=0)
    ptrValidLoader    = DataLoader(dataset=ptrValidSet, batch_size=intBatch, shuffle=False, num_workers=0)
    ##
    ##  Initialize history before the epoch loop
    # listLoss = []
    dictHistory = {"TrainLoss":[], "ValidLoss":[], "ValidAccuracy":[]}
    listValidResult = []
    for e in range(intEpoch):
        bMax  = ptrTrainAugLoader.dataset.len
        nSum  = 0
        eLoss = 0.0
        for _, b in enumerate(ptrTrainAugLoader, 0):
            ##
            ##  Get x and y
            x, y = b
            n = x.shape[0]
            ##
            ##  Zero the parameter gradients
            ptrOptimizer.zero_grad()
            ##
            ##  Forward + Backward + Optimize
            tenOutput = funModel.cuda()(x.cuda())
            bLoss = funCriterion(tenOutput, y.cuda())
            bLoss.backward()
            ptrOptimizer.step()
            ##
            ##  Update
            eLoss += (bLoss.item() * n)
            nSum = nSum + n
            if(nSum==bMax):
                ##
                ##  Check valid
                with torch.no_grad():
                    tenValidOutput = funModel.cpu()(ptrValidLoader.dataset.x)
                    floatValidLoss = funCriterion(tenValidOutput, ptrValidLoader.dataset.y).item()
                    _, tenValidPrediction = torch.max(tenValidOutput.data, 1)
                    floatAccuracy  = accuracy_score(ptrValidLoader.dataset.y, tenValidPrediction)
                    ##
                    ##  Update
                    eLoss = eLoss / bMax
                    # listLoss = listLoss + [eLoss]
                    dictHistory['ValidLoss'] = dictHistory['ValidLoss'] + [floatValidLoss]
                    dictHistory['ValidAccuracy'] = dictHistory['ValidAccuracy'] + [floatAccuracy]
                    dictHistory['TrainLoss'] = dictHistory['TrainLoss'] + [eLoss]
                    print("Epoch: %s" % e)
                    print("Train loss: %s" % eLoss, "Valid loss: %s " % floatValidLoss)
                    print("Valid accuracy: %s " % floatAccuracy)
    ##
    ##  Summary on valid
    _, tenValidPrediction = torch.max(tenValidOutput.data, 1)
    dataValidPrediction = DataFrame({"Prediction":array(tenValidPrediction)})
    dataValidOutput = DataFrame(array(tenValidOutput)).reset_index(drop=True)
    dataValidOutput.columns = ["Prob-" + str(i) for i in dataValidOutput.columns]
    dataValidTable = dataValidTable.reset_index(drop=True)
    dataValidResult = pandas.concat([dataValidTable, dataValidOutput, dataValidPrediction], axis = 1)
    dataValidResult.to_csv("Holdout/" + str(i) + "-ValidResult.csv", index=False)
    ##
    ##  Save model and history
    dataHistory = DataFrame(dictHistory)
    dataHistory.to_csv("Holdout/" + str(i) + "-History.csv", index=False)
    torch.save(funModel, "Holdout/" + str(i) + '-Model.h5')



dataParameterSet = pandas.read_csv("Holdout/ParameterSet.csv")
dataParameterSet['ValidAccuracy'] = None
for i, p in dataParameterSet.iterrows():
    funModel = torch.load("Holdout/" + str(i) + "-Model.h5")
    ##
    ##  Evaluate valid
    with torch.no_grad():
        tenValidOutput = funModel.cpu()(ptrValidLoader.dataset.x)
        floatValidLoss = funCriterion(tenValidOutput, ptrValidLoader.dataset.y).item()
        _, tenValidPrediction = torch.max(tenValidOutput.data, 1)
        floatAccuracy  = accuracy_score(ptrValidLoader.dataset.y, tenValidPrediction)
        ##
        ##  Accuracy
        dataParameterSet.at[i,'ValidAccuracy'] = floatAccuracy

Requirement

import os
import random
import shutil
import xml.etree.ElementTree as et
import xlrd
import cv2 as cv
import numpy
import pandas
import PIL
import sklearn
import torch
import torchvision
from os import listdir
from numpy import array
from pandas import DataFrame
from PIL import Image as pil
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import ParameterGrid
from sklearn.model_selection import StratifiedKFold as fold
from sklearn.model_selection import train_test_split as holdout
from torch import cuda, nn, optim
from torch import FloatTensor, DoubleTensor, LongTensor, from_numpy
from torch.nn import Linear, Softmax, functional, Sequential, Module, CrossEntropyLoss
from torch.nn.functional import relu
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
##
##  Crop image
def CropChestImage(image, e = 5):
    gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
    binary = numpy.where(gray>250,255, 0).astype("uint8")
    binary[0:int(binary.shape[1]/6), 0:int(binary.shape[0]*2/5)] = 0
    contour, _ = cv.findContours(binary, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    for i in contour:
        i = numpy.squeeze(i, axis = 1)
        Xmax = max(i[:,0])
        Xmin = min(i[:,0])
        Ymax = max(i[:,1])
        Ymin = min(i[:,1])
        crop = image[Ymin+e:Ymax-e, Xmin+e:Xmax-e]
        stop = (crop.shape[0]>400) & (crop.shape[1]>400)
        if(stop):
            break
    return(crop)
##
##  Data table
def ChestDataTable():
    listId = listdir("./DataSet/Clean/")
    listLabel = [str.split(i, "_")[0] for i in listId]
    dataTable = DataFrame({"Id":listId, "Label": listLabel})
    dataTable['Target'] = dataTable['Label'].replace({"n0":0, "n1":1, "n2":2, "n3":3})
    return(dataTable)
##
##  Data set
class ChestDataSet(Dataset):
    def __init__(self, data):
        funTransform = transforms.Compose([transforms.Resize((224, 224)),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomVerticalFlip(),
                                         transforms.RandomRotation(degrees = 360),
                                         transforms.ToTensor()])
        ##
        ##  Get table
        dataTable = data
        ##
        ##  Get image
        listImage = []
        for i in dataTable['Id']:
            iImage = pil.open("DataSet/Clean/" + i).resize((224,224))
            iImage = funTransform(iImage)
            iImage = array(iImage)
            listImage.append(iImage)
        ##
        ##  Get torch format
        tenImage  = from_numpy(array(listImage)).float()
        tenTarget = from_numpy(array(dataTable['Target'])).type(LongTensor)
        ##
        ##  Custom
        self.len = tenTarget.shape[0]
        self.x = tenImage
        self.y = tenTarget
    def __getitem__(self, index):
        return self.x[index], self.y[index]
    def __len__(self):
        return self.len
# ##
# ##
# def MakeFold(data, target, k=3):
#     dataTable = data
#     dataTable['Fold'] = None
#     funFold = fold(n_splits=k)
#     for i, (_, index) in enumerate(funFold.split(dataTable, dataTable['Target'])):
#         dataTable.at[index,'Fold'] = i + 1
#     return(dataTable)

def AugChestTable(data, each = 5000):
    listAugTable = []
    dataTable = data
    setTarget = set(dataTable['Target'])
    for t in setTarget:
        try:
            dataAug = dataTable.loc[dataTable['Target']==t].sample(each, replace = False)
        except ValueError:
            ##  Class has fewer than `each` rows, so sample with replacement
            dataAug = dataTable.loc[dataTable['Target']==t].sample(each, replace = True)
        listAugTable.append(dataAug)
    dataAugTable = pandas.concat(listAugTable)
    return(dataAugTable)
##
##  Res base model
class ResBaseNet(Module):
    def __init__(self):
        super(ResBaseNet, self).__init__()
        self.funFeatExtra = Sequential(*[i for i in list(models.resnet34().children())[:-1]])
        self.funOutputLayer = Linear(512, 4)
        self.funSoftmax     = Softmax(dim=1)
    def forward(self, x):
        x = self.funFeatExtra(x)
        x = x.view(-1, x.shape[1])
        x = self.funOutputLayer(x)
        x = self.funSoftmax(x)
        return x
##
##  Dense base model
class DenseBaseNet(Module):
    def __init__(self):
        super(DenseBaseNet, self).__init__()
        self.funFeatExtra = Sequential(*[i for i in list(models.densenet121().children())[:-1]])
        self.funOutputLayer = Linear(50176, 4)
        self.funSoftmax     = Softmax(dim=1)
    def forward(self, x):
        x = self.funFeatExtra(x)
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = self.funOutputLayer(x)
        x = self.funSoftmax(x)
        return x
##
##  Check gpu information
def CheckGpuInf():
    intDevice = cuda.current_device()
    print("Current device: %s" % intDevice)
    print("Device count: %s" % cuda.device_count())
    print("Device name: %s" % cuda.get_device_name(intDevice))
    return
##
##  Set seed
def SetSeed():
    ##
    ##   Fix result
    random.seed(1)
    numpy.random.seed(123)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return

I have solved this kind of problem with tensorflow.keras before, so maybe switching to TensorFlow would be better.

nn.CrossEntropyLoss expects logits, while your model seems to apply a softmax to the output.
While this shouldn’t be responsible for the NaNs, you should still remove it, as it might slow down your training.
To debug further, you could use Anomaly Detection and try to see where these NaN values get produced.
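As a minimal sketch (reusing the DenseBaseNet class from your post), the forward pass would return the raw logits, and anomaly detection can be switched on globally:

import torch
from torch.nn import Linear, Module, Sequential
from torchvision import models

# Raise an error at the exact operation that creates a NaN during backward.
torch.autograd.set_detect_anomaly(True)

class DenseBaseNet(Module):
    def __init__(self):
        super(DenseBaseNet, self).__init__()
        self.funFeatExtra = Sequential(*list(models.densenet121().children())[:-1])
        self.funOutputLayer = Linear(50176, 4)
    def forward(self, x):
        x = self.funFeatExtra(x)
        x = x.view(x.shape[0], -1)
        # Return raw logits; CrossEntropyLoss applies log_softmax internally.
        return self.funOutputLayer(x)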

Based on your description it seems like you are seeing NaN values in the output after one epoch. Is that correct?
Just for the sake of debugging, could you set drop_last=True in your DataLoader and run the code again?
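Using the names from your code, that would be something like:

# drop_last=True skips a final batch that is smaller than batch_size
ptrTrainAugLoader = DataLoader(dataset=ptrTrainAugSet, batch_size=intBatch,
                               shuffle=True, num_workers=0, drop_last=True)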

The output layer returns NaN after some mini-batch.
In that iteration, there are no NaNs in the gradients.
That NaN output then turns everything NaN afterwards.
Thanks, I will try your suggestion.
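To narrow down which layer produces the first NaN, I may also register a forward hook on every submodule; a sketch (the hook name is illustrative):

import torch

def MakeNanHook(name):
    # Report the first module whose forward output contains NaN.
    def hook(module, inputs, output):
        if torch.is_tensor(output) and torch.isnan(output).any():
            raise RuntimeError("NaN produced in module: %s" % name)
    return hook

for name, module in funModel.named_modules():
    module.register_forward_hook(MakeNanHook(name))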

@hzyu810225, have you found the cause?

I’m having the same problem in a multiclass classification problem with my vanilla baseline model.

# Create model class

class FER2013_model_v0(nn.Module):
  def __init__(
      self,
      input_shape: int,
      hidden_units: int,
      output_shape: int
  ):
    super().__init__()
    self.layer_stack = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=input_shape,
                  out_features=hidden_units,
                  bias=True), # True is default - here just for learning purposes
        nn.Linear(in_features=hidden_units,
                  out_features=output_shape,
                  bias=True)
    )

  def forward(self, x):
    return self.layer_stack(x)

This is my training loop:

# Imports assumed by this snippet
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.metrics import balanced_accuracy_score

# Set the seed
torch.manual_seed(44)

# Set hyperparameters
EPOCHS = 3

# Create training and test loop

for epoch in tqdm(range(EPOCHS)):
  print(f'Starting epoch: {epoch} \n ============')

  # Training
  train_loss = 0

  # Loop through training batches
  for batch, (images, labels) in enumerate(train_dataloader):
    # Set model to training mode
    model_0.train()
    # 1. Forward pass
    predictions = model_0(images)

    # 2. Calculate batch loss
    loss = loss_fn(predictions, labels)
    train_loss += loss.item() # accumulate a Python float so the graph is freed

    # 3. Optimizer zero grad
    optimizer.zero_grad()

    # 4. Loss backward
    loss.backward()

    # 5. Optimizer step
    optimizer.step()

    # Print out what's happening
    if batch % 100 == 0:
      print(predictions)
      print(f'Completed training of {batch*len(images)} / {len(train_dataloader.dataset)} samples')

  # Divide total train loss by length of train dataloader
  train_loss /= len(train_dataloader)

  # Anomaly detection

  with torch.autograd.detect_anomaly():
    input = torch.rand(size=(32, 48, 48), requires_grad=True)
    output = model_0(input)
    # backward() needs a scalar, so reduce the non-scalar output first
    output.sum().backward()
  
  # Testing
  test_loss, test_acc, test_bal_acc = 0, 0, 0

  # Set model to evaluation mode
  model_0.eval()

  with torch.inference_mode():
    # Loop through validation batches
    for images, labels in validation_dataloader:
      # Forward pass
      predictions = model_0(images)
      print(predictions)

      # Calculate loss
      test_loss += loss_fn(predictions, labels)

      # Compute metrics for the batch
      print(predictions)
      test_acc += accuracy_obj(predictions.argmax(dim=1), labels)
      test_bal_acc += balanced_accuracy_score(y_pred=predictions.argmax(dim=1), y_true=labels)

    # Compute test_loss for the epoch
    test_loss /= len(validation_dataloader)

    # Compute test metrics for the epoch
    test_acc /= len(validation_dataloader)
    test_bal_acc /= len(validation_dataloader)

  # Print epoch summary
  # Visualize one first-layer weight row (assumes hidden_units=10 and 48x48 inputs)
  plt.imshow(model_0.state_dict()['layer_stack.1.weight'].reshape((10, 48, 48))[0], cmap='Greys_r')
  print('==============================')
  print(f'\nTrain loss at epoch {epoch}: {train_loss:.4f} | Test loss: {test_loss:.4f}, Test acc: {test_acc:.4f}, Test balanced acc: {test_bal_acc:.4f}')

loss_fn is categorical cross-entropy with class weights (the FER2013 dataset is fairly imbalanced) and the optimizer is SGD with lr=0.001.
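For reference, a minimal sketch of that setup (the weight values below are placeholders, not the ones I actually use):

import torch
from torch import nn, optim

# Placeholder per-class weights for the 7 FER2013 classes.
class_weights = torch.tensor([1.0, 9.4, 1.9, 0.6, 0.8, 1.2, 0.8])
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.SGD(model_0.parameters(), lr=0.001)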

Please help!