Issue with using Lin's Concordance Correlation Coefficent as a Loss Function for Speech Emotion Recognition

Hi all. I am doing Speech Emotion Recognition on the IEMOCAP dataset. This a regression problem because I am predicting the dimensional scores. I am using Lin’s Concordance Correlation Coefficient as my loss function, in which case the loss is being maximized. The model I am finetuning is wav2vec2-large-robust. My issue is that the loss frequently becomes unstable, which makes the model predict nan values. I have tried standarizing my data, so that the dimensional scores have 0 mean and 1 variance, but that doesn’t seem to help. I can not tell if there is an issue with my code that is causing this problem. I am posting the training code below for inspection, kindly let me know if you spot an issue. Any help will be greatly appreciated. Thank you in advance!


def run_epoch(model, dataloader, optimizer, criterion, device, accumulation_steps, checkpoint):

    # if path to checkpoint exits, load from there
    if checkpoint:
        print()
        print(f"Training model from checkpoint {checkpoint}")
        checkpoint = torch.load(checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    model.train()

    total_loss = 0
    total_samples = 0

    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):

        # zero the parameter gradients
        if (i+1) % accumulation_steps == 0:
            optimizer.zero_grad()

        audio = batch["input_values"].to(device)
        label = batch["labels"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        output = model(input_values=audio, attention_mask=attention_mask)
        logits = output["logits"]

        # Split the logits and labels into separate tensors for each prediction
        logits_valence, logits_activation, logits_dominance = logits.split(1, dim=-1)
        labels_valence, labels_activation, labels_dominance = label.split(1, dim=-1)

        # Compute the loss separately for each prediction (negative to maximize the loss)
        loss_valence = -criterion(logits_valence, labels_valence)
        loss_activation = -criterion(logits_activation, labels_activation)
        loss_dominance = -criterion(logits_dominance, labels_dominance)

        # Combine the losses
        loss = (loss_valence + loss_activation + loss_dominance) / 3
        loss = loss / accumulation_steps
        loss.backward()

        if (((i+1) % accumulation_steps == 0) or (i + 1 == len(dataloader))):
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item()

        total_samples += len(label)

        # delete tensor no longer needed 
        del audio, label, output, attention_mask

        # empty cache 
        torch.cuda.empty_cache()

    return total_loss / total_samples



def run_evaluation(model, dataloader, criterion, device):

    model.eval()

    total_loss = 0
    total_samples = 0

    actual_labels = []
    predicted_labels = []

    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):

            audio = batch["input_values"].to(device)
            label = batch["labels"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            output = model(input_values=audio, attention_mask=attention_mask)
            logits = output["logits"]

            # Split the logits and labels into separate tensors for each prediction
            logits_valence, logits_activation, logits_dominance = logits.split(1, dim=-1)
            labels_valence, labels_activation, labels_dominance = label.split(1, dim=-1)

            # # Compute the loss separately for each prediction
            loss_valence = -criterion(logits_valence, labels_valence)
            loss_activation = -criterion(logits_activation, labels_activation)
            loss_dominance = -criterion(logits_dominance, labels_dominance)

            # Combine the losses
            loss = (loss_valence + loss_activation + loss_dominance) / 3
            total_loss += loss.item()

            total_samples += len(label)

            predicted_labels.extend(logits.detach().cpu().tolist())
            actual_labels.extend(label.cpu().tolist())

            del audio, label, output

            torch.cuda.empty_cache()

    return total_loss / total_samples, actual_labels, predicted_labels

def train_model(model, args):
    #total_steps = len(args.train_dataloader) * args.epochs
    #scheduler = get_linear_schedule_with_warmup(args.optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    best_ccc_loss = float('-inf')  # Keep track of the best validation accuracy
    epochs = args.epochs
    
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        best_ccc_loss = checkpoint['best_ccc_loss']
        # last_epoch = checkpoint['epoch']

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        # Run training and validation for each epoch
        train_loss = run_epoch(model, args.train_dataloader, args.optimizer, args.criterion, args.device, args.accumulation_steps, args.checkpoint)
        val_loss, actual_labels, predicted_labels = run_evaluation(model, args.eval_dataloader, args.criterion, args.device)
        # Print the metrics for this epoch
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Validation Loss: {val_loss:.4f}")

        # Save the model if it has the best validation accuracy so far
        if val_loss > best_ccc_loss:
            best_ccc_loss = val_loss
            name = f"epoch_{epoch+1}_val_loss_{best_ccc_loss:.4f}"
            torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(), 
                    'optimizer_state_dict': args.optimizer.state_dict(), 
                    'best_ccc_loss': best_ccc_loss,
                    }, f"{args.save_path}/model_{name}.pt")
            pred_df = pd.DataFrame(predicted_labels, columns=['valence', 'activation', 'dominance'])
            pred_df.to_csv(f'{args.save_path}/preds_{name}.csv', index=None)


    return predicted_labels