NaN loss after some batches when using ResNet/ViT as feature extractor

Hi everyone,

I'm trying to use ViT/ResNet as part of a hybrid model that combines image and numerical data. My idea is to use the ‘VisionBackbone’ as a feature extractor and then concatenate the last layer of the ViT/ResNet with my numerical variables before passing everything through some additional dense/fc layers.

Since I ran into the same problem as described in AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 24, 31]) for the ViT, I went with the solution proposed there (using: from torchvision.models.feature_extraction import create_feature_extractor). With this change, my training runs randomly hit NaN loss values in every epoch, which breaks the training.
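
To narrow down where the NaNs first show up, I've been adding a quick check along these lines (just a sketch; assert_finite is a helper name I made up, and the commented-out calls refer to the variables in the training loop below):

import torch

def assert_finite(name, t):
    # stop as soon as a tensor contains NaN/Inf so the offending batch can be inspected
    if not torch.isfinite(t).all():
        raise RuntimeError(f"{name} contains NaN/Inf values")

# inside the training loop:
#   assert_finite("inputs_img", inputs_img)
#   assert_finite("inputs_num", inputs_num)
#   assert_finite("loss", loss)

# alternatively, autograd can report the op that first produced a NaN (slow, debugging only):
# torch.autograd.set_detect_anomaly(True)

So far this hasn't told me whether the images or the numerical inputs are to blame.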

My code:

import os
import time
import json
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import torchvision
import numpy as np
import pandas as pd
from tqdm import tqdm
from torchmetrics import F1Score
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.functional import precision_recall
from torchvision.models.feature_extraction import create_feature_extractor


class hybrid(nn.Module):
    def __init__(self,VisionBackbone,BackboneOutFeatures):
        super(hybrid, self).__init__()
        self.VisionBackbone = VisionBackbone
        # the 46 numerical values are concatenated with the backbone features
        self.NumInput = nn.Linear(46+BackboneOutFeatures,100)
        self.fc1 = nn.Linear(100, 3)
    def forward(self, x_img,x_num):
        x_img = self.VisionBackbone(x_img)
        # create_feature_extractor returns a dict of {node_name: tensor}
        if isinstance(x_img, dict):
            x_img = next(iter(x_img.values()))
        # note: squeeze removes ALL singleton dims, including the batch dim if batch_size == 1
        x_img = torch.squeeze(x_img)
        x_img = torch.flatten(x_img, 1)
        x = torch.cat((x_num,x_img),-1)
        x = F.relu(self.NumInput(x))
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

class ImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir,num_file1,num_file2, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)

        with open(num_file1) as f:
            num_file1_json = json.load(f)
        columns_num = num_file1_json['index']['var1']
        self.num_file1_df = pd.DataFrame(columns=columns_num)
        for key in num_file1_json.keys():
            self.num_file1_df.loc[key] = num_file1_json[key]['proportions']
        # fillna returns a copy, so assign the result back
        self.num_file1_df = self.num_file1_df.fillna(value=0)

        # load the second numerical file (assuming a CSV indexed by filename)
        self.num_file2_df = pd.read_csv(num_file2).set_index('filename')
        columns_to_norm = ['var1','var2']
        # z-score normalization; std() is 0 for a constant column, making the division 0/0 = NaN
        self.num_file2_df[columns_to_norm] = (self.num_file2_df[columns_to_norm]-self.num_file2_df[columns_to_norm].mean())/self.num_file2_df[columns_to_norm].std()
        self.num_file2_df = self.num_file2_df.fillna(value=0)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = torchvision.io.read_image(img_path)/255.0
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        num1 = self.num_file1_df.loc[self.img_labels.iloc[idx, 0]].values
        num2 = self.num_file2_df.loc[self.img_labels.iloc[idx, 0].split('_')[0]].values
        return image,torch.cat((torch.Tensor(num1),torch.Tensor(num2)),axis=0), label

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    f1 = F1Score(num_classes=3,average='none').to(device)
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            phase_outputs = []
            phase_targets = []
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs_img,inputs_num, labels in tqdm(dataloaders[phase],desc=phase):
                inputs_img = inputs_img.to(device)
                inputs_num= inputs_num.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs_img,inputs_num)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs_img.size(0)
                running_corrects += torch.sum(preds == labels.data)

                phase_outputs.append(preds)
                phase_targets.append(labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            phase_outputs = torch.cat(phase_outputs)
            phase_targets = torch.cat(phase_targets)
            epoch_f1 = f1(phase_outputs, phase_targets)
            epoch_precision,epoch_recall = precision_recall(phase_outputs, phase_targets,average='none', num_classes=3)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} F1: {torch.nanmean(epoch_f1):.4f} Recall: {torch.nanmean(epoch_recall):.4f} Precision: {torch.nanmean(epoch_precision):.4f}' )

            writer.add_scalar(f"{phase}/Loss", epoch_loss, epoch)
            writer.add_scalar(f"{phase}/Accuracy", epoch_acc, epoch)
            writer.add_scalar(f"{phase}/F1", torch.nanmean(epoch_f1), epoch)
            writer.add_scalar(f"{phase}/precision", torch.nanmean(epoch_precision), epoch)
            writer.add_scalar(f"{phase}/recall", torch.nanmean(epoch_recall), epoch)
            writer.add_scalar(f"{phase}/learning_rate", scheduler.get_last_lr()[0], epoch)

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    print("saving latest weights")
    torch.save(model,save_path+"latest_model.pt")
    time_elapsed = time.time() - since
    writer.add_scalar("train/time", time_elapsed, 15)
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Loss: {best_loss:.4f}')

    # load best model weights
    print("saving best val weights")
    model.load_state_dict(best_model_wts)
    torch.save(model,save_path+"best_val_model.pt")
    return model


parser = argparse.ArgumentParser(description='Train the hybrid image + numerical model')
parser.add_argument("--model_type", help="model to train",type=str)
parser.add_argument("--training_name", help="name of training",type=str)
parser.add_argument("--num_epochs", help="number of training epochs, default value of 15",type=int,default=15)


args = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

save_path=f"models/{args.model_type}/{args.training_name}/"
writer = SummaryWriter(log_dir=save_path)

model_ft=None
transform=None
if args.model_type=='Resnet':
    VisionBackbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    # 'avgpool' yields a (batch, 512, 1, 1) tensor right before the final fc layer
    VisionBackbone = create_feature_extractor(VisionBackbone, return_nodes=['avgpool'])
    model_ft = hybrid(VisionBackbone,512)
    transform=transforms.Compose([
        transforms.Resize((224,224)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
elif args.model_type=='ViT':
    VisionBackbone = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
    # 'getitem_5' corresponds to the class-token slice just before the classification head
    VisionBackbone = create_feature_extractor(VisionBackbone, return_nodes=['getitem_5'])
    model_ft = hybrid(VisionBackbone, 768)
    transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
else:
    raise ValueError(f"unknown model_type: {args.model_type}")

model_ft = model_ft.to(device)

training_data = ImageDataset(
    annotations_file='data/train.csv',
    img_dir='data/images/',
    num_file1='path',
    num_file2='path',
    transform=transform
)
val_data = ImageDataset(
    annotations_file='data/val.csv',
    img_dir='data/images',
    num_file1='path',
    num_file2='path',
    transform=transform
)

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=64, shuffle=True)
dataloaders={"train":train_dataloader,"val":val_dataloader}
criterion = torch.nn.CrossEntropyLoss()

optimizer_ft = torch.optim.AdamW(model_ft.parameters(), lr=2e-4,betas=(0.9, 0.999), eps=1e-08,)

exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,step_size=1,gamma=0.925)
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=args.num_epochs)
writer.close()
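
One thing I still want to rule out is NaNs coming from the numerical features themselves, since the z-score normalization in ImageDataset divides by std(). Here is the rough sanity scan I have in mind (a sketch, assuming the dataset returns (image, numeric, label) tuples as above):

# sketch: scan the whole training set for non-finite numerical features
bad_samples = 0
for i in range(len(training_data)):
    _, num, _ = training_data[i]
    if not torch.isfinite(num).all():
        print(f"sample {i} contains NaN/Inf in its numerical features")
        bad_samples += 1
print(f"{bad_samples} bad samples out of {len(training_data)}")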

Some things I'd like to mention:

  1. the training code worked fine before I changed my model to a hybrid one (i.e. just using ViT or ResNet directly, outside of my ‘hybrid’ class)
  2. I've changed some of the variable names for the sake of this example
  3. I'm not super familiar with PyTorch or PyTorch FX (see the snippet after this list for how I list the FX node names)
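
For reference, this is how the valid FX node names can be listed with torchvision's get_graph_node_names (a minimal sketch; the tail of the printed list should show where a node like 'getitem_5' sits):

import torchvision
from torchvision.models.feature_extraction import get_graph_node_names

vit = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
# returns the node names for train and eval mode (the two lists can differ)
train_nodes, eval_nodes = get_graph_node_names(vit)
print(eval_nodes[-10:])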

Were you able to resolve the issue? I am facing a similar error now.